1
0
Эх сурвалжийг харах

refactor: 调整安全内存函数

SongZihuan 3 жил өмнө
parent
commit
2198b95abe

+ 0 - 6
include/tool/mem.h

@@ -8,10 +8,4 @@
 
 #include "mem.inline.h"
 
-#ifndef MEM_NOT_DEFINE
-#define free(p) (aFuntool::safeFree((p)))
-#define calloc(n, obj) (obj *)(aFuntool::safeCalloc(n, sizeof(obj)))
-#define calloc_size(n, size) (aFuntool::safeCalloc(n, size))
-#endif
-
 #endif  // AFUN_MEM_H

+ 8 - 7
include/tool/mem.inline.h

@@ -8,11 +8,12 @@
 
 /* 取代calloc函数 */
 namespace aFuntool {
-    template <typename T>
-    static void *safeFree(T *ptr) {if (ptr != nullptr) free((void *)ptr); return nullptr;}
+    template <typename T = void *>
+    static T *safeFree(T *ptr) {if (ptr != nullptr) free((void *)ptr); return nullptr;}
 
-    static void *safeCalloc(size_t n, size_t size){
-        void *re = calloc(n, size);
+    template <typename T = void *>
+    static T *safeCalloc(size_t n, size_t size){
+        T *re = (T *)calloc(n, size);
         if (re == nullptr) {
             if (SysLogger)
                 fatalErrorLog(SysLogger, EXIT_FAILURE, "The memory error");
@@ -22,9 +23,9 @@ namespace aFuntool {
         return re;
     }
 
-    template <typename T>
-    static void *safeCalloc(size_t n, T &t){
-        void *re = calloc(n, sizeof(decltype(*t)));  // 自动推断类型
+    template <typename T = void *>
+    static T *safeCalloc(size_t n = 1){
+        T *re = (T *)calloc(n, sizeof(T));  // 自动推断类型
         if (re == nullptr) {
             if (SysLogger)
                 fatalErrorLog(SysLogger, EXIT_FAILURE, "The memory error");

+ 2 - 2
src/tool/byte.cpp

@@ -60,7 +60,7 @@ namespace aFuntool {
             return true;
         }
 
-        str = calloc(len + 1, char);
+        str = safeCalloc<char>(len + 1);
         return fread(str, sizeof(char), len, file) == len;
     }
 
@@ -77,7 +77,7 @@ namespace aFuntool {
             return true;
         }
 
-        char *tmp = calloc(len + 1, char);
+        char *tmp = safeCalloc<char>(len + 1);
         size_t ret = fread(tmp, sizeof(char), len, file);
         str = tmp;
         free(tmp);

+ 0 - 2
src/tool/log.cpp

@@ -39,8 +39,6 @@
 #define getpid() (long)getpid()
 #endif
 
-#undef calloc
-
 namespace aFuntool {
     typedef struct LogNode LogNode;
     struct LogNode {  // 日志信息记录节点

+ 2 - 2
src/tool/md5.cpp

@@ -33,7 +33,7 @@ namespace aFuntool {
     };
 
     MD5_CTX *MD5Init(){
-        auto context = calloc(1, MD5_CTX);
+        auto context = safeCalloc<MD5_CTX>();
         context->count[0] = 0;
         context->count[1] = 0;
         context->state[0] = 0x67452301;
@@ -203,7 +203,7 @@ namespace aFuntool {
         if ((fd = fileOpen(path, "rb")) == nullptr)
             throw FileOpenException(path);
 
-        char *md5str = calloc(MD5_STRING, char);
+        char *md5str = safeCalloc<char>(MD5_STRING);
         MD5_CTX *md5 = MD5Init();
         while (true) {
             ret = fread(data, 1, READ_DATA_SIZE, fd);

+ 8 - 8
src/tool/stdio_.cpp

@@ -330,7 +330,7 @@ namespace aFuntool {
         if (tmp_len == 0)
             return 0;
 
-        auto tmp = calloc(tmp_len + 1, wchar_t);
+        auto tmp = safeCalloc<wchar_t>(tmp_len + 1);
         if (MultiByteToWideChar(from, 0, str, -1, tmp, tmp_len) == 0)
             return 0;
 
@@ -338,7 +338,7 @@ namespace aFuntool {
         if (dest_len == 0)
             return 0;
 
-        *dest = calloc(dest_len + 1, char);
+        *dest = safeCalloc<char>(dest_len + 1);
         int re = WideCharToMultiByte(to, 0, tmp, -1, *dest, dest_len, nullptr, nullptr);
 
         free(tmp);
@@ -353,7 +353,7 @@ namespace aFuntool {
         if (tmp_len == 0)
             return 0;
 
-        *dest = calloc(tmp_len + 1, wchar_t);
+        *dest = safeCalloc<wchar_t>(tmp_len + 1);
         return MultiByteToWideChar(from, 0, str, -1, *dest, tmp_len);
     }
 
@@ -365,21 +365,21 @@ namespace aFuntool {
         if (dest_len == 0)
             return 0;
 
-        *dest = calloc(dest_len + 1, char);
+        *dest = safeCalloc<char>(dest_len + 1);
         return WideCharToMultiByte(to, 0, str, -1, *dest, dest_len, nullptr, nullptr);
     }
 
     int fgets_stdin(char **dest, int len){
         int re = 0;
         if (!_isatty(_fileno(stdin))) {
-            *dest = calloc(len + 1, char);
+            *dest = safeCalloc<char>(len + 1);
             re = fgets(*dest, len, stdin) != nullptr;
             if (!re)
                 free(*dest);
             return re;
         }
 
-        char *wstr = calloc(len, char);
+        char *wstr = safeCalloc<char>(len);
         UINT code_page = GetConsoleCP();
         if (fgets_stdin_(wstr, len) != nullptr)  // 已经有互斥锁
             re = convertMultiByte(dest, wstr, code_page, CP_UTF8);
@@ -450,7 +450,7 @@ namespace aFuntool {
         if (buf_len == 0)
             buf_len = 1024;
         buf_len += 10;  // 预留更多位置
-        char *buf = calloc(buf_len, char);
+        char *buf = safeCalloc<char>(buf_len);
         size_t re = vsnprintf(buf, buf_len, format, ap);
         if (fputs_std_(buf, std) == EOF)
             re = 0;
@@ -470,7 +470,7 @@ namespace aFuntool {
     // 默认Linux平台均使用utf-8
     
     int fgets_stdin(char **dest, int len) {
-        *dest = calloc(len, char);
+        *dest = safeCalloc<char>(len);
         if (fgets(*dest, len, stdin) == nullptr)
             return 0;
         return 1;

+ 2 - 2
src/tool/string.cpp

@@ -11,8 +11,8 @@
 #define EQ_STR(str1, str2) (!strcmp((str1), (str2)))
 #define EQ_WSTR(wid1, wid2) (!wcscmp((wid1), (wid2)))
 
-#define NEW_STR(size) calloc((size) + 1, char)
-#define NEW_WSTR(size) calloc((size) + 1, wchar_t)
+#define NEW_STR(size) safeCalloc<char>((size) + 1)
+#define NEW_WSTR(size) safeCalloc<wchar_t>((size) + 1)
 
 #define STR_LEN(p) (((p) == NULL) ? 0 : strlen((p)))
 #define WSTR_LEN(p) (((p) == NULL) ? 0 : wcslen((p)))

+ 2 - 2
test/src/tool-mem.cpp

@@ -1,11 +1,11 @@
 #include "aFuntool.h"
 
 int main() {
-    int *p = calloc(1, int);
+    int *p = aFuntool::safeCalloc<int>();
     *p = 10;
     free(p);
 
-    p = calloc(1, int);
+    p = aFuntool::safeCalloc<int>(1);
     *p = 10;
     free(p);
     return 0;