sgsPy
structurally guided sampling
Loading...
Searching...
No Matches
raster.h
Go to the documentation of this file.
1/******************************************************************************
2 *
3 * Project: sgs
4 * Purpose: GDALDataset wrapper for raster operations
5 * Author: Joseph Meyer
6 * Date: June, 2025
7 *
8 ******************************************************************************/
9
14
15#pragma once
16
#include <algorithm>
#include <cmath>
#include <filesystem>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <gdal_priv.h>
#include <ogr_spatialref.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <utils/helper.h>
25
//Upper bound, in bytes, on a single band buffer that readRasterBand will allocate
//for direct pixel access (1073741824 = 2^30 bytes = 1 GiB).
#define GIGABYTE 1073741824
28
29namespace sgs {
30namespace raster {
31
//short alias for the pybind11 namespace used throughout this header
namespace py = pybind11;
//pybind11 user-defined literals (e.g. "name"_a), scoped to sgs::raster
using namespace pybind11::literals;
34
private:
    //owning handle to the wrapped GDAL dataset
    GDALDatasetUniquePtr p_dataset;

    //full-resolution pixel buffers, indexed by 0-based band number, plus a flag per
    //band recording whether the buffer has been loaded (and must be released)
    std::vector<void *> rasterBandPointers;
    std::vector<bool> rasterBandRead;

    //resampled ("display") pixel buffers, indexed by 0-based band number, plus their
    //loaded flags; displayRasterWidth/Height describe the size of these buffers and are
    //compared against the requested size in getRasterBandAsMemView (-1 = unset)
    std::vector<void *> displayRasterBandPointers;
    std::vector<bool> displayRasterBandRead;
    int displayRasterWidth = -1;
    int displayRasterHeight = -1;

    //affine geotransform coefficients as returned by GDALDataset::GetGeoTransform
    double geotransform[6];
    //name of the dataset's spatial reference (OGRSpatialReference::GetName)
    std::string crs = "";
    //pretty-printed WKT of the dataset's spatial reference
    std::string proj = "";

    //temporary directory that is removed (recursively) on close/destruction, if set
    std::string tempDir = "";

    //set by close(); the destructor returns immediately when true
    bool destroyed = false;
    //true when band buffers point into a Python (numpy) buffer and must not be freed here
    bool externalRasterData = false;
97 void readRasterBand(int width, int height, int band) {
98 GDALDataType type = this->getRasterBandType(band);
99 size_t size = this->getRasterBandTypeSize(band);
100 size_t max = std::numeric_limits<size_t>::max();
101
102 //perform size checks
103 if (max / size < static_cast<size_t>(width) ||
104 max / (size * static_cast<size_t>(width)) < static_cast<size_t>(height)) {
105 throw std::runtime_error("raster too large to fit in memory.");
106 }
107
108 if (static_cast<size_t>(height) * static_cast<size_t>(width) * size > GIGABYTE) {
109 throw std::runtime_error("sgs does not allow allocation of a raster into memory for direct pixel access purposes if it would be larger than 1 gigabyte.");
110 }
111
112 //allocate data
113 void *p_data = VSIMalloc3(height, width, size);
114
115 //perform raster read on current band
116 CPLErr err = this->p_dataset->GetRasterBand(band + 1)->RasterIO(
117 GF_Read, //GDALRWFlag eRWFlag
118 0, //int nXOff
119 0, //int nYOff
120 this->getWidth(), //int nXSize
121 this->getHeight(), //int nYSize
122 p_data, //void *pData
123 width, //int nBufXSize
124 height, //int nBufYSize
125 type, //GDALDataType eBufType
126 0, //int nPixelSpace
127 0 //int nLineSpace
128 );
129 if (err) {
130 throw std::runtime_error("error reading raster band from dataset.");
131 }
132
133 //update dislpay information as required
134 if (width != this->getWidth() || height != this->getHeight()) {
135 this->displayRasterBandRead[band] = true;
136 this->displayRasterBandPointers[band] = p_data;
137 }
138 else {
139 this->rasterBandRead[band] = true;
140 this->rasterBandPointers[band] = p_data;
141 }
142 }
143
155 template <typename T>
156 py::buffer getBuffer(size_t size, void *p_buffer, int width, int height) {
157 //see https://pybind11.readthedocs.io/en/stable/advanced/pycpp/numpy.html#memory-view
158 return py::memoryview::from_buffer(
159 (T*)p_buffer, //buffer
160 {height, width}, //shape
161 {size * width, size} //stride
162 );
163 }
164
172 void createFromDataset(GDALDataset *p_dataset) {
173 this->p_dataset = GDALDatasetUniquePtr(p_dataset);
174
175 //geotransform
176 CPLErr cplerr = this->p_dataset->GetGeoTransform(this->geotransform);
177 if (cplerr) {
178 throw std::runtime_error("error getting geotransform from dataset.");
179 }
180
181 //crs
182 if(std::string(this->p_dataset->GetProjectionRef()).length() == 0) {
183 throw std::runtime_error("raster dataset does not have a projection definition.");
184 }
185 this->crs = std::string(OGRSpatialReference(this->p_dataset->GetProjectionRef()).GetName());
186
187 //proj
188 char *p_proj;
189 OGRErr ogrerr = OGRSpatialReference(this->p_dataset->GetProjectionRef()).exportToPrettyWkt(&p_proj);
190 if (ogrerr) {
191 throw std::runtime_error("error getting projection as WKT from dataset.");
192 }
193 this->proj = std::string(p_proj);
194 CPLFree(p_proj);
195
196 //initialize (but don't read) raster band pointers
197 this->rasterBandPointers = std::vector<void *>(this->getBandCount(), nullptr);
198 this->rasterBandRead = std::vector<bool>(this->getBandCount(), false);
199 this->displayRasterBandPointers = std::vector<void *>(this->getBandCount(), nullptr);
200 this->displayRasterBandRead = std::vector<bool>(this->getBandCount(), false);
201 }
202
203 public:
215 GDALRasterWrapper(std::string filename, std::string projDBPath) {
216 //set proj.db search path to search for the proj.db file which is included in sgs package
217 char **paths = nullptr;
218 paths = CSLAddString(paths, projDBPath.c_str());
219 OSRSetPROJSearchPaths(paths);
220 CSLDestroy(paths);
221
222 //must register drivers before trying to open a dataset
223 GDALAllRegister();
224
225 //dataset
226 GDALDataset *p_dataset = GDALDataset::FromHandle(GDALOpen(filename.c_str(), GA_ReadOnly));
227 if (!p_dataset) {
228 throw std::runtime_error("file given does not result in a valid dataset, check to ensure file path is accurate.");
229 }
230
231 this->createFromDataset(p_dataset);
232 }
233
244 GDALRasterWrapper(GDALDataset *p_dataset, std::vector<void *> bands) {
245 this->createFromDataset(p_dataset);
246 this->rasterBandPointers = bands;
247 this->rasterBandRead = std::vector<bool>(bands.size(), true);
248 }
249
256 GDALRasterWrapper(GDALDataset *p_dataset) {
257 this->createFromDataset(p_dataset);
258 }
259
274 GDALRasterWrapper(py::buffer buffer, std::vector<double> geotransform, std::string projection, std::vector<double> nanVals, std::vector<std::string> names, std::string projDBPath) {
275 //set proj.db search path to search for the proj.db file which is included in sgs package
276 char **paths = nullptr;
277 paths = CSLAddString(paths, projDBPath.c_str());
278 OSRSetPROJSearchPaths(paths);
279 CSLDestroy(paths);
280
281 py::buffer_info info = buffer.request();
282
283 //get height width and band count from pybuffer
284 int width, height;
285 size_t bandCount, bandSize;
286 if (info.ndim == 3) {
287 bandCount = info.shape[0];
288 height = info.shape[1];
289 width = info.shape[2];
290 bandSize = info.strides[0];
291 }
292 else if (info.ndim == 2) {
293 bandCount = 1;
294 height = info.shape[0];
295 width = info.shape[1];
296 }
297 else {
298 throw std::runtime_error("dimension of numpy array must be 2 or 3");
299 }
300
301 if (bandCount != names.size()) {
302 throw std::runtime_error("band names array does not have the same number of bands as the py buffer");
303 }
304
305 //get data type from pybuffer
306 GDALDataType type;
307 size_t size;
308 if (info.format == py::format_descriptor<int8_t>::format()) {
309 type = GDT_Int8;
310 size = sizeof(int8_t);
311 }
312 else if (info.format == py::format_descriptor<int16_t>::format()) {
313 type = GDT_Int16;
314 size = sizeof(int16_t);
315 }
316 else if (info.format == py::format_descriptor<uint16_t>::format()) {
317 type = GDT_UInt16;
318 size = sizeof(uint16_t);
319 }
320 else if (info.format == py::format_descriptor<int32_t>::format()) {
321 type = GDT_Int32;
322 size = sizeof(int32_t);
323 }
324 else if (info.format == py::format_descriptor<uint32_t>::format()) {
325 type = GDT_UInt32;
326 size = sizeof(uint32_t);
327 }
328 else if (info.format == py::format_descriptor<float>::format()) {
329 type = GDT_Float32;
330 size = sizeof(float);
331 }
332 else if (info.format == py::format_descriptor<double>::format()) {
333 type = GDT_Float64;
334 size = sizeof(double);
335 }
336 else {
337 throw std::runtime_error("data type of array must be one of int8, int16, uint16, int32, uint32, float32, or float64.");
338 }
339
340 GDALAllRegister();
341 GDALDataset *p_dataset = helper::createVirtualDataset("MEM", width, height, geotransform.data(), projection);
342 std::vector<void *> bands(bandCount);
343 for (size_t i = 0; i < bandCount; i++) {
345 band.p_buffer = (void *)((size_t)info.ptr + (i * bandSize));
346 band.type = type;
347 band.size = size;
348 band.nan = nanVals[i];
349 band.name = names[i];
350 helper::addBandToMEMDataset(p_dataset, band);
351 bands[i] = band.p_buffer;
352
353 }
354
355 this->createFromDataset(p_dataset);
356 this->rasterBandPointers = bands;
357 this->rasterBandRead = std::vector<bool>(bandCount, true);
358 this->externalRasterData = true;
359 }
360
//~GDALRasterWrapper() body: releases band buffers, closes the dataset and removes tempDir.
    //close() sets destroyed = true after doing the same cleanup, so there is nothing left to do
    if (destroyed) {
        return;
    }

    for (int i = 0; i < this->getBandCount(); i++) {
        //if the raster data is coming from a numpy array (this->externalRasterData true), then
        //the memory will be cleaned up by Pythons garbage collector
        if (this->rasterBandRead[i] && !this->externalRasterData) {
            CPLFree(this->rasterBandPointers[i]);
        }

        //display (resampled) buffers are always allocated by this wrapper
        if (this->displayRasterBandRead[i]) {
            CPLFree(this->displayRasterBandPointers[i]);
        }
    }

    //release ownership from the unique_ptr and close explicitly with GDALClose
    GDALClose(GDALDataset::ToHandle(this->p_dataset.release()));

    //NOTE(review): std::filesystem::remove_all can throw filesystem_error; an exception
    //escaping a destructor calls std::terminate — consider the std::error_code overload
    if (this->tempDir != "") {
        std::filesystem::path temp = this->tempDir;
        std::filesystem::remove_all(temp);
    }
}
389
401 void close(void) {
402 for (int i = 0; i < this->getBandCount(); i++) {
403 //if the raster data is coming from a numpy array (this->externalRasterData true), then
404 //the memory will be cleaned up by Pythons garbage collector
405 if (this->rasterBandRead[i]) {
406 CPLFree(this->rasterBandPointers[i]);
407 }
408
409 if (this->displayRasterBandRead[i]) {
410 CPLFree(this->displayRasterBandPointers[i]);
411 }
412 }
413
414 GDALClose(GDALDataset::ToHandle(this->p_dataset.release()));
415
416 if (this->tempDir != "") {
417 std::filesystem::path temp = this->tempDir;
418 std::filesystem::remove_all(temp);
419 }
420
421 destroyed = true;
422 }
423
429 GDALDataset *getDataset() {
430 return this->p_dataset.get();
431 }
432
438 std::string getDriver() {
439 return std::string(this->p_dataset->GetDriverName())
440 + "/"
441 + std::string(this->p_dataset->GetDriver()->GetMetadataItem(GDAL_DMD_LONGNAME));
442 }
443
449 std::string getFullProjectionInfo() {
450 return this->proj;
451 }
452
458 std::string getCRS(){
459 return this->crs;
460 }
461
467 int getWidth() {
468 return this->p_dataset->GetRasterXSize();
469 }
470
//getHeight(): number of raster rows (GDAL "Y size") of the wrapped dataset.
    return this->p_dataset->GetRasterYSize();
}
479
//getBandCount(): number of bands in the wrapped dataset.
    return this->p_dataset->GetRasterCount();
}
488
495 double getXMax() {
496 int width = this->p_dataset->GetRasterXSize();
497 int height = this->p_dataset->GetRasterYSize();
498 return std::max(
499 this->geotransform[0],
500 this->geotransform[0] + this->geotransform[1] * width + this->geotransform[2] * height
501 );
502 }
503
510 double getXMin() {
511 int width = this->p_dataset->GetRasterXSize();
512 int height = this->p_dataset->GetRasterYSize();
513 return std::min(
514 this->geotransform[0],
515 this->geotransform[0] + this->geotransform[1] * width + this->geotransform[2] * height
516 );
517 }
518
525 double getYMax() {
526 int width = this->p_dataset->GetRasterXSize();
527 int height = this->p_dataset->GetRasterYSize();
528 return std::max(
529 this->geotransform[3],
530 this->geotransform[3] + this->geotransform[4] * width + this->geotransform[5] * height
531 );
532 }
533
540 double getYMin() {
541 int width = this->p_dataset->GetRasterXSize();
542 int height = this->p_dataset->GetRasterYSize();
543 return std::min(
544 this->geotransform[3],
545 this->geotransform[3] + this->geotransform[4] * width + this->geotransform[5] * height
546 );
547 }
548
555 double getPixelWidth() {
556 return std::abs(this->geotransform[5]);
557 }
558
565 double getPixelHeight() {
566 return std::abs(this->geotransform[1]);
567 }
568
575 std::vector<std::string> getBands(){
576 std::vector<std::string> retval;
577
578 for( auto&& p_band : this->p_dataset->GetBands() ){
579 retval.push_back(p_band->GetDescription());
580 }
581
582 return retval;
583 }
584
590 double *getGeotransform() {
591 return this->geotransform;
592 }
593
599 double getBandNoDataValue(int band){
600 GDALRasterBand *p_band = this->getRasterBand(band);
601 return p_band->GetNoDataValue();
602 }
603
624 py::buffer getRasterBandAsMemView(int width, int height, int band) {
625 bool display = (width != this->getWidth() || height != this->getHeight());
626 void *p_buffer;
627 GDALDataType type = this->getRasterBandType(band);
628
629 //allocate raster if required
630 if (!display && !this->rasterBandRead[band]) {
631 this->readRasterBand(width, height, band);
632 }
633
634 //(re)allocate display raster if required
635 if (display) {
636 if (width != this->displayRasterWidth || height != this->displayRasterHeight) {
637 free(this->displayRasterBandPointers[band]);
638 this->displayRasterBandRead[band] = false;
639 }
640
641 if (this->displayRasterBandRead[band] == false) {
642 this->readRasterBand(width, height, band);
643 }
644 }
645
646 //get the (allocated) data buffer
647 p_buffer = (!display) ?
648 this->rasterBandPointers[band] :
649 this->displayRasterBandPointers[band];
650
651 switch(type) {
652 case GDT_Int8:
653 return getBuffer<int8_t>(sizeof(int8_t), p_buffer, width, height);
654 case GDT_UInt16:
655 return getBuffer<uint16_t>(sizeof(uint16_t), p_buffer, width, height);
656 case GDT_Int16:
657 return getBuffer<int16_t>(sizeof(int16_t), p_buffer, width, height);
658 case GDT_UInt32:
659 return getBuffer<uint32_t>(sizeof(uint32_t), p_buffer, width, height);
660 case GDT_Int32:
661 return getBuffer<int32_t>(sizeof(int32_t), p_buffer, width, height);
662 case GDT_Float32:
663 return getBuffer<float>(sizeof(float), p_buffer, width, height);
664 case GDT_Float64:
665 return getBuffer<double>(sizeof(double), p_buffer, width, height);
666 default:
667 throw std::runtime_error("raster pixel data type not supported.");
668 }
669 }
670
//releaseBandBuffers(): forget every full-resolution band buffer WITHOUT freeing it.
    //Pointers are nulled and read flags cleared, so close() and the destructor (which
    //only CPLFree buffers whose read flag is set) will no longer release them.
    //Presumably ownership has been handed to another owner first — confirm with callers.
    //Display buffers are left untouched.
    for (size_t i = 0; i < this->rasterBandPointers.size(); i++) {
        rasterBandPointers[i] = nullptr;
        rasterBandRead[i] = false;
    }
}
689
696 GDALRasterBand *getRasterBand(int band) {
697 return this->p_dataset->GetRasterBand(band + 1);
698 }
699
706 void *getRasterBandBuffer(int band) {
707 if (!this->rasterBandRead[band]) {
708 this->readRasterBand(
709 this->getWidth(),
710 this->getHeight(),
711 band
712 );
713 }
714
715 return this->rasterBandPointers[band];
716 }
717
724 GDALDataType getRasterBandType(int band) {
725 GDALRasterBand *p_band = this->p_dataset->GetRasterBand(band + 1);
726 return p_band->GetRasterDataType();
727 }
728
735 size_t getRasterBandTypeSize(int band) {
736 switch (this->getRasterBandType(band)) {
737 case GDALDataType::GDT_Int8:
738 return 1;
739 case GDALDataType::GDT_UInt16:
740 case GDALDataType::GDT_Int16:
741 return 2;
742 case GDALDataType::GDT_UInt32:
743 case GDALDataType::GDT_Int32:
744 case GDALDataType::GDT_Float32:
745 return 4;
746 case GDALDataType::GDT_Float64:
747 return 8;
748 default:
749 std::string errorMsg = "GDALDataType of band " + std::to_string(band) + " not supported.";
750 throw std::runtime_error(errorMsg);
751 }
752 }
753
760 void write(std::string filename) {
761 std::filesystem::path filepath = filename;
762 std::string extension = filepath.extension().string();
763
764 if (extension != ".tif") {
765 throw std::runtime_error("write only supports .tif files right now");
766 }
767
768 GDALDriver *p_driver = GetGDALDriverManager()->GetDriverByName("GTiff");
769
770 GDALClose(p_driver->CreateCopy(filename.c_str(), this->p_dataset.get(), (int)false, nullptr, nullptr, nullptr));
771 }
772
779 void setTempDir(std::string tempDir) {
780 this->tempDir = tempDir;
781 }
782
788 std::string getTempDir() {
789 return this->tempDir;
790 }
791
799 std::vector<double> getGeotransformArray() {
800 std::vector<double> retval(6);
801 for (int i = 0; i < 6; i++) {
802 retval[i] = this->geotransform[i];
803 }
804
805 return retval;
806 }
807
816 std::string getDataType() {
817 GDALDataType type = this->getRasterBandType(0);
818 for (int i = 1; i < this->getBandCount(); i++) {
819 if (this->getRasterBandType(i) != type) {
820 return "";
821 }
822 }
823
824 switch (type) {
825 case GDT_Int8: return "int8";
826 case GDT_UInt16: return "uint16";
827 case GDT_Int16: return "int16";
828 case GDT_UInt32: return "uint32";
829 case GDT_Int32: return "int32";
830 case GDT_Float32: return "float32";
831 case GDT_Float64: return "float64";
832 default: throw std::runtime_error("GDAL data type not supported");
833 }
834 }
835};
836
837} //namespace raster
838} //namespace sgs
double getYMax()
Definition raster.h:525
void * getRasterBandBuffer(int band)
Definition raster.h:706
GDALDataset * getDataset()
Definition raster.h:429
~GDALRasterWrapper()
Definition raster.h:365
std::vector< double > getGeotransformArray()
Definition raster.h:799
GDALRasterWrapper(GDALDataset *p_dataset, std::vector< void * > bands)
Definition raster.h:244
std::string getDataType()
Definition raster.h:816
int getWidth()
Definition raster.h:467
void releaseBandBuffers(void)
Definition raster.h:683
GDALRasterWrapper(py::buffer buffer, std::vector< double > geotransform, std::string projection, std::vector< double > nanVals, std::vector< std::string > names, std::string projDBPath)
Definition raster.h:274
void write(std::string filename)
Definition raster.h:760
size_t getRasterBandTypeSize(int band)
Definition raster.h:735
std::string getDriver()
Definition raster.h:438
py::buffer getRasterBandAsMemView(int width, int height, int band)
Definition raster.h:624
double getBandNoDataValue(int band)
Definition raster.h:599
void setTempDir(std::string tempDir)
Definition raster.h:779
double * getGeotransform()
Definition raster.h:590
GDALDataType getRasterBandType(int band)
Definition raster.h:724
int getBandCount()
Definition raster.h:485
double getXMax()
Definition raster.h:495
std::string getCRS()
Definition raster.h:458
int getHeight()
Definition raster.h:476
GDALRasterWrapper(std::string filename, std::string projDBPath)
Definition raster.h:215
std::string getFullProjectionInfo()
Definition raster.h:449
GDALRasterBand * getRasterBand(int band)
Definition raster.h:696
void close(void)
Definition raster.h:401
GDALRasterWrapper(GDALDataset *p_dataset)
Definition raster.h:256
double getXMin()
Definition raster.h:510
std::string getTempDir()
Definition raster.h:788
double getPixelWidth()
Definition raster.h:555
double getYMin()
Definition raster.h:540
std::vector< std::string > getBands()
Definition raster.h:575
double getPixelHeight()
Definition raster.h:565
void addBandToMEMDataset(GDALDataset *p_dataset, RasterBandMetaData &band)
Definition helper.h:467
GDALDataset * createVirtualDataset(std::string driverName, int width, int height, double *geotransform, std::string projection)
Definition helper.h:294
Definition raster.h:30
Definition pca.h:23
#define GIGABYTE
Definition raster.h:27
Definition helper.h:87
void * p_buffer
Definition helper.h:89
size_t size
Definition helper.h:91
GDALDataType type
Definition helper.h:90
double nan
Definition helper.h:93
std::string name
Definition helper.h:92