blob: 36e692f4a5bfab18a93cb5facf2f1fb2da996eaa [file] [log] [blame] [edit]
#include "tools/ml/py_bridge.h"
#include <algorithm>
#include <iterator>
#include <cstdlib>
#include "cnpy.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <dirent.h>
#include <fnmatch.h> // For pattern matching
#include <unistd.h> // For unlink (file deletion)
#include <filesystem>
#include <sys/stat.h>
// One recorded variable: the shape of a single record plus the values of
// every record appended so far, flattened back to back.
struct Var {
// Dimensions of one record, stored the first time the variable is added
// (empty for a scalar). Later additions must use the same shape.
std::vector<size_t> shape;
// Values of all records in insertion order. Incoming floats are stored as
// doubles; the record count is data.size() / product(shape).
std::vector<double> data;
};
// The variables that belong to one output file.
struct DataFile {
// Recorded variables keyed by variable name; each one is saved as a separate
// array when the data set is closed.
std::map<std::string, Var> vars;
};
// Top-level in-memory data set handed to callers as an opaque void* context.
struct PyDataFile {
// Output files keyed by the `file` argument of the add functions. On close,
// each key becomes the middle part of "<filename>_<key>.npz".
std::map<std::string, DataFile> files;
};
extern "C" void* py_datafile_open() {
struct PyDataFile* df = new struct PyDataFile();
return (void*)df;
}
/// Returns the number of elements in one record of the given shape, i.e. the
/// product of all dimensions. An empty shape (a scalar) yields 1, and any
/// zero dimension yields 0.
///
/// The product is accumulated in size_t rather than int so that large shapes
/// are neither truncated nor subject to signed-overflow undefined behaviour.
static size_t get_data_length(const std::vector<size_t>& shape) {
  size_t data_len = 1;
  for (const size_t dim : shape) {
    data_len *= dim;
  }
  return data_len;
}
static void py_datafile_add_int(void** context, const char* file,
const char* var_name, float* var_data,
std::vector<size_t> shape) {
struct PyDataFile* df = (struct PyDataFile*)*context;
if (df == nullptr) {
printf("Creating py datafile\n");
df = new struct PyDataFile();
*context = (void*)df;
}
if (!df->files[file].vars.count(var_name)) {
df->files[file].vars[var_name].shape = shape;
} else {
const auto& old_shape = df->files[file].vars[var_name].shape;
if (old_shape != shape) {
std::cout << "Illegal shape for var:" << var_name << ", was ";
std::copy(old_shape.begin(), old_shape.end(), std::ostream_iterator<int>(std::cout, ", "));
std::cout << " now ";
std::copy(shape.begin(), shape.end(), std::ostream_iterator<int>(std::cout, ", "));
std::cout << std::endl;
std::exit(1);
}
}
auto& data = df->files[file].vars[var_name].data;
for (size_t i = 0; i < get_data_length(shape); i++) {
data.push_back(var_data[i]);
}
}
/// C entry point: appends a scalar record (one value) of `var_name` to the
/// in-memory file `file`. See py_datafile_add_int() for the handling of
/// `context` and of shape mismatches.
extern "C" void py_datafile_add_0d(void** context, const char* file,
                                   const char* var_name, float* var_data) {
  // A scalar has no dimensions, so its shape is the empty vector.
  const std::vector<size_t> scalar_shape;
  py_datafile_add_int(context, file, var_name, var_data, scalar_shape);
}
/// C entry point: appends a one-dimensional record of `shape0` values of
/// `var_name` to the in-memory file `file`. See py_datafile_add_int() for the
/// handling of `context` and of shape mismatches.
extern "C" void py_datafile_add_1d(void** context, const char* file,
                                   const char* var_name, float* var_data,
                                   size_t shape0) {
  const std::vector<size_t> record_shape{shape0};
  py_datafile_add_int(context, file, var_name, var_data, record_shape);
}
/// C entry point: appends a two-dimensional record of `shape0` x `shape1`
/// values of `var_name` to the in-memory file `file`. See
/// py_datafile_add_int() for the handling of `context` and of shape
/// mismatches.
extern "C" void py_datafile_add_2d(void** context, const char* file,
                                   const char* var_name, float* var_data,
                                   size_t shape0, size_t shape1) {
  const std::vector<size_t> record_shape{shape0, shape1};
  py_datafile_add_int(context, file, var_name, var_data, record_shape);
}
void py_datafile_fold(void** _to, void** _from) {
struct PyDataFile* to = (struct PyDataFile*)*_to;
struct PyDataFile* from = (struct PyDataFile*)*_from;
if (from == nullptr) {
std::cout << "\x1b[91mpy_data_fold: from is null\x1b[0m" << std::endl;
return;
}
if (to == nullptr) {
to = new struct PyDataFile();
*_to = (void*)to;
}
for (const auto& e0 : from->files) {
for (const auto& e1 : e0.second.vars) {
const auto& from_var = e1.second;
if (!to->files[e0.first].vars.count(e1.first)) {
to->files[e0.first].vars[e1.first].shape = from_var.shape;
} else {
const auto& to_shape = to->files[e0.first].vars[e1.first].shape;
if (to_shape != from_var.shape) {
std::cout << "Illegal shape for var:" << e1.first << ", was ";
std::copy(to_shape.begin(), to_shape.end(), std::ostream_iterator<int>(std::cout, ", "));
std::cout << " now ";
std::copy(from_var.shape.begin(), from_var.shape.end(), std::ostream_iterator<int>(std::cout, ", "));
std::cout << std::endl;
std::exit(1);
}
}
auto& to_var = to->files[e0.first].vars[e1.first];
to_var.data.insert(to_var.data.end(), from_var.data.begin(), from_var.data.end());
}
}
delete from;
*_from = nullptr;
}
/// Deletes every file matching "<_filename>*.npz".
///
/// `_filename` is split at its last '/': the part before it is the directory
/// to scan ("." when there is no slash, "/" when the only slash is the
/// leading one) and the part after it is the file-name prefix. Directory
/// entries are matched with fnmatch() against "<prefix>*.npz", so glob
/// metacharacters in the prefix are interpreted as a pattern.
///
/// @param _filename  Path prefix of the files to remove.
/// @param count      Out: number of files successfully removed.
/// @return 0 on success; 1 if an argument is null, the directory cannot be
///         opened, or at least one matching file could not be removed.
///
/// All paths are built with std::string, so long prefixes or directory names
/// cannot overflow or be truncated by a fixed-size buffer.
extern "C" int py_datafile_delete_old_files(const char* _filename, size_t* count) {
  if (count == nullptr) {
    return 1;
  }
  *count = 0;
  if (_filename == nullptr) {
    return 1;
  }

  const std::string path(_filename);
  const std::string::size_type last_slash = path.rfind('/');
  std::string dir_name;
  std::string base_name;
  if (last_slash == std::string::npos) {
    dir_name = ".";
    base_name = path;
  } else {
    // A leading slash is the root directory, not an empty directory name.
    dir_name = (last_slash == 0) ? std::string("/") : path.substr(0, last_slash);
    base_name = path.substr(last_slash + 1);
  }

  DIR* dir = opendir(dir_name.c_str());
  if (dir == nullptr) {
    return 1;
  }

  const std::string pattern = base_name + "*.npz";
  int ret = 0;
  while (const struct dirent* entry = readdir(dir)) {
    if (fnmatch(pattern.c_str(), entry->d_name, 0) != 0) {
      continue;
    }
    const std::string filepath = dir_name + "/" + entry->d_name;
    if (unlink(filepath.c_str()) != 0) {
      ret = 1;  // keep going, but report that something could not be removed
    } else {
      ++(*count);
    }
  }
  closedir(dir);
  return ret;
}
/// Returns true if the directory that would contain `file_path` exists.
///
/// Only the parent of `file_path` is examined. A path with no directory part
/// (e.g. "out.npz") refers to the current directory and is reported as
/// existing.
bool dir_exists(const std::string& file_path) {
  const std::filesystem::path parent =
      std::filesystem::path(file_path).parent_path();
  if (parent.empty()) {
    return true;
  }
  struct stat info;
  if (stat(parent.c_str(), &info) != 0) {
    return false;  // Path does not exist or is not accessible
  }
  // S_ISDIR compares the whole file-type field. Testing the S_IFDIR bit alone
  // would also accept block devices and sockets, whose type codes contain it.
  return S_ISDIR(info.st_mode);
}
extern "C" void py_datafile_close(void** context, const char* _filename) {
struct PyDataFile* df = (struct PyDataFile*)*context;
if (df == nullptr) {
std::cout << "\x1b[91mpy_data_close: df is null\x1b[0m" << std::endl;
return;
}
if (!dir_exists(std::string(_filename))) {
std::cout << "\x1b[91mpy_data_close: tring to create a file "
<< std::string(_filename) << " in a directory that doesn't "
<< "exist.\x1b[0m" << std::endl;
return;
}
for (const auto& e0 : df->files) {
const auto filename = std::string(_filename) + "_" + e0.first + ".npz";
for (const auto& e1 : e0.second.vars) {
const auto& var = e1.second;
const unsigned int data_length = get_data_length(var.shape);
std::vector<size_t> npz_shape;
npz_shape.push_back(var.data.size() / data_length);
npz_shape.insert(npz_shape.end(), var.shape.begin(), var.shape.end());
cnpy::npz_save(filename, e1.first, &var.data[0], npz_shape, "a");
}
}
delete df;
*context = nullptr;
}