forked from secretflow/yacl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
16,361 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
load("//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test") | ||
|
||
package(default_visibility = ["//visibility:public"]) | ||
|
||
yacl_cc_library( | ||
name = "TSet", | ||
srcs = ["TSet.cc"], | ||
hdrs = ["TSet.h"], | ||
deps = [ | ||
"//yacl/crypto/rand:rand", | ||
"//yacl/crypto/hmac:hmac_sha256", | ||
"//yacl/crypto/hash:ssl_hash" | ||
], | ||
) | ||
|
||
yacl_cc_test( | ||
name = "TSet_test", | ||
srcs = ["TSet_test.cc"], | ||
deps = [ | ||
":TSet", | ||
], | ||
) | ||
|
||
yacl_cc_library( | ||
name = "sse", | ||
srcs = ["sse.cc"], | ||
hdrs = ["sse.h"], | ||
deps = [ | ||
"//yacl/io/rw:csv_reader", | ||
"//yacl/io/stream:file_io", | ||
"//yacl/crypto/rand:rand", | ||
"//yacl/crypto/ecc/openssl:openssl", | ||
"//yacl/math/mpint:mpint", | ||
"//yacl/crypto/primitives/sse:TSet", | ||
], | ||
) | ||
|
||
yacl_cc_test( | ||
name = "sse_test", | ||
srcs = ["sse_test.cc"], | ||
deps = [ | ||
"//yacl/crypto/primitives/sse:sse", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
// Copyright 2024 Ant Group Co., Ltd. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "yacl/crypto/primitives/sse/TSet.h" | ||
|
||
namespace yacl::crypto::primitives::sse { | ||
|
||
TSet::TSet(int B, int S, int lambda, int n_lambda) | ||
: B_(B), S_(S), lambda_(lambda), n_lambda_(n_lambda) { | ||
initialize(); | ||
} | ||
|
||
bool TSet::areVectorsEqual(const std::vector<uint8_t>& vec1, | ||
const std::vector<uint8_t>& vec2) const { | ||
return std::equal(vec1.begin(), vec1.end(), vec2.begin(), vec2.end()); | ||
} | ||
|
||
std::vector<uint8_t> TSet::pack( | ||
const std::pair<std::vector<uint8_t>, std::string>& data) const { | ||
const auto& first = data.first; // 前部分(vector<uint8_t>) | ||
const auto& second = data.second; // 后部分(string) | ||
|
||
std::vector<uint8_t> result; | ||
|
||
// 1. 添加前部分(9 个字节的 vector<uint8_t>) | ||
result.insert(result.end(), first.begin(), first.end()); | ||
|
||
// 2. 添加后部分的长度(4 字节,记录字符串长度,假设最大长度 <= 2^32-1) | ||
uint32_t second_length = static_cast<uint32_t>(second.size()); | ||
uint8_t length_bytes[4]; | ||
std::memcpy(length_bytes, &second_length, 4); | ||
result.insert(result.end(), length_bytes, length_bytes + 4); | ||
|
||
// 3. 添加后部分的内容(string 的字节数据) | ||
result.insert(result.end(), second.begin(), second.end()); | ||
|
||
return result; | ||
} | ||
|
||
std::pair<std::vector<uint8_t>, std::string> TSet::unpack( | ||
const std::vector<uint8_t>& packed_data) const { | ||
// 1. 提取前部分(固定 9 个字节) | ||
std::vector<uint8_t> first(packed_data.begin(), packed_data.begin() + 9); | ||
|
||
// 2. 提取后部分的长度(从第 9 到第 13 字节) | ||
uint32_t second_length = 0; | ||
std::memcpy(&second_length, packed_data.data() + 9, 4); | ||
|
||
// 3. 提取后部分内容(字符串数据) | ||
std::string second(packed_data.begin() + 13, | ||
packed_data.begin() + 13 + second_length); | ||
|
||
return {first, second}; | ||
} | ||
|
||
std::string TSet::vectorToString(const std::vector<uint8_t>& vec) const { | ||
std::string result; | ||
for (auto& byte : vec) { | ||
result += std::to_string(static_cast<int>(byte)); | ||
} | ||
return result; | ||
} | ||
|
||
void TSet::initialize() { | ||
// 初始化 TSet 数组 | ||
TSet_.resize(B_, std::vector<Record>(S_)); | ||
|
||
// 初始化 Free 数组 | ||
Free_.resize(B_); | ||
for (int i = 0; i < B_; ++i) { | ||
for (int j = 1; j <= S_; ++j) { | ||
Free_[i].insert(j); | ||
} | ||
} | ||
|
||
// 初始化 TSet 中的每个 record | ||
for (int i = 0; i < B_; ++i) { | ||
for (int j = 0; j < S_; ++j) { | ||
TSet_[i][j].label.resize(lambda_ / 8, 0); // 初始化为长度为 λ 的位字符串 | ||
TSet_[i][j].value.resize(n_lambda_ / 8 + 1, | ||
0); // 初始化为长度为 n(λ) + 1 的位字符串 | ||
} | ||
} | ||
} | ||
|
||
std::pair<std::vector<std::vector<TSet::Record>>, std::string> TSet::TSetSetup( | ||
const std::unordered_map< | ||
std::string, std::vector<std::pair<std::vector<uint8_t>, std::string>>>& | ||
T, | ||
const std::vector<std::string>& keywords) { | ||
std::vector<std::vector<Record>> TSet; | ||
std::vector<std::set<int>> Free; | ||
|
||
restart: | ||
initialize(); | ||
|
||
std::vector<uint8_t> rand_bytes_Kt = yacl::crypto::RandBytes(32); | ||
std::string Kt = vectorToString(rand_bytes_Kt); | ||
yacl::crypto::HmacSha256 hmac_F_line_Tset(Kt); | ||
for (const auto& keyword : keywords) { | ||
hmac_F_line_Tset.Reset(); | ||
hmac_F_line_Tset.Update(keyword); | ||
auto mac = hmac_F_line_Tset.CumulativeMac(); | ||
std::string stag = vectorToString(mac); | ||
const auto& t = T.at(keyword); // 使用 at 方法访问元素 | ||
yacl::crypto::HmacSha256 hmac_F_Tset(stag); | ||
size_t i = 1; | ||
for (const auto& si : t) { | ||
hmac_F_Tset.Reset(); | ||
hmac_F_Tset.Update(std::to_string(i)); | ||
auto mac = hmac_F_Tset.CumulativeMac(); | ||
std::string mac_str = vectorToString(mac); | ||
|
||
yacl::crypto::Sm3Hash sm3; | ||
sm3.Reset(); | ||
std::vector<uint8_t> hash = sm3.Update(mac_str).CumulativeHash(); | ||
size_t hash_value = 0; | ||
for (size_t i = 0; i < hash.size(); ++i) { | ||
hash_value = (hash_value * 256 + hash[i]) % B_; | ||
} | ||
size_t b = (hash_value % B_); | ||
// 只取前 128 位(16 字节) | ||
std::vector<uint8_t> L(hash.begin(), hash.begin() + lambda_ / 8); | ||
yacl::crypto::Sha256Hash sha256; | ||
sha256.Reset(); | ||
std::vector<uint8_t> K = sha256.Update(mac_str).CumulativeHash(); | ||
if (Free_[b].empty()) { | ||
goto restart; | ||
} | ||
|
||
// 从 Free[b] 中随机选择一个元素 j,并删除 | ||
auto it = Free_[b].begin(); | ||
std::advance( | ||
it, yacl::crypto::RandU32() % Free_[b].size()); // 随机移动到某个位置 | ||
int j = *it; | ||
Free_[b].erase(j); | ||
|
||
j = (j - 1) % S_; | ||
TSet_[b][j].label = L; | ||
|
||
// 计算 (β|si) ⊕ K | ||
auto packed_si = pack(si); | ||
size_t beta = (i < t.size()) ? 1 : 0; | ||
std::vector<uint8_t> beta_si; | ||
beta_si.push_back(static_cast<uint8_t>(beta)); | ||
beta_si.insert(beta_si.end(), packed_si.begin(), packed_si.end()); | ||
std::vector<uint8_t> value_xor_k(beta_si.size()); | ||
for (size_t k = 0; k < beta_si.size(); ++k) { | ||
value_xor_k[k] = beta_si[k] ^ K[k % K.size()]; | ||
} | ||
TSet_[b][j].value = value_xor_k; | ||
i++; | ||
} | ||
} | ||
|
||
return {TSet_, Kt}; | ||
} | ||
|
||
std::vector<uint8_t> TSet::TSetGetTag(const std::string& Kt, | ||
const std::string& w) const { | ||
// std::string Kt = vectorToString(rand_bytes_Kt); | ||
yacl::crypto::HmacSha256 hmac_F_line_Tset(Kt); | ||
hmac_F_line_Tset.Reset(); | ||
hmac_F_line_Tset.Update(w); | ||
auto mac = hmac_F_line_Tset.CumulativeMac(); | ||
return mac; | ||
} | ||
|
||
std::vector<std::pair<std::vector<uint8_t>, std::string>> TSet::TSetRetrieve( | ||
const std::vector<std::vector<Record>>& TSet, | ||
const std::string& stag) const { | ||
yacl::crypto::HmacSha256 hmac_F_Tset(stag); | ||
|
||
std::vector<std::pair<std::vector<uint8_t>, std::string>> t; | ||
uint8_t beta = 1; | ||
size_t i = 1; | ||
|
||
while (beta == 1) { | ||
hmac_F_Tset.Reset(); | ||
hmac_F_Tset.Update(std::to_string(i)); | ||
auto mac = hmac_F_Tset.CumulativeMac(); | ||
std::string mac_str = vectorToString(mac); | ||
|
||
yacl::crypto::Sm3Hash sm3; | ||
sm3.Reset(); | ||
std::vector<uint8_t> hash = sm3.Update(mac_str).CumulativeHash(); | ||
size_t hash_value = 0; | ||
for (size_t i = 0; i < hash.size(); ++i) { | ||
hash_value = (hash_value * 256 + hash[i]) % B_; | ||
} | ||
size_t b = (hash_value % B_); | ||
// 只取前 128 位(16 字节) | ||
std::vector<uint8_t> L(hash.begin(), hash.begin() + lambda_ / 8); | ||
yacl::crypto::Sha256Hash sha256; | ||
sha256.Reset(); | ||
std::vector<uint8_t> K = sha256.Update(mac_str).CumulativeHash(); | ||
auto& B = TSet[b]; | ||
int j = 0; | ||
for (; j < S_; ++j) { | ||
if (areVectorsEqual(B[j].label, L)) { | ||
std::vector<uint8_t> v(B[j].value.size()); | ||
for (size_t k = 0; k < v.size(); ++k) { | ||
v[k] = B[j].value[k] ^ K[k % K.size()]; | ||
} | ||
// Let β be the first bit of v, and s the remaining n(λ) bits of v | ||
beta = v[0]; | ||
std::vector<uint8_t> s(v.begin() + 1, v.end()); | ||
auto unpacked_s = unpack(s); | ||
t.push_back(unpacked_s); | ||
} | ||
} | ||
++i; | ||
if (i > 100) { | ||
break; | ||
} | ||
} | ||
|
||
return t; | ||
} | ||
|
||
} // namespace yacl::crypto::primitives::sse |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
// Copyright 2024 Ant Group Co., Ltd. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <unordered_map> | ||
#include <set> | ||
#include <utility> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "yacl/crypto/hash/ssl_hash.h" | ||
#include "yacl/crypto/hmac/hmac_sha256.h" | ||
#include "yacl/crypto/rand/rand.h" | ||
|
||
|
||
|
||
namespace yacl::crypto::primitives::sse { | ||
|
||
class TSet { | ||
public: | ||
struct Record { | ||
std::vector<uint8_t> label; // 存储长度为 λ 的位字符串 | ||
std::vector<uint8_t> value; // 存储长度为 n(λ) + 1 的位字符串 | ||
}; | ||
|
||
TSet(int B, int S, int lambda, int n_lambda); | ||
|
||
bool areVectorsEqual(const std::vector<uint8_t>& vec1, | ||
const std::vector<uint8_t>& vec2) const; | ||
|
||
std::vector<uint8_t> pack( | ||
const std::pair<std::vector<uint8_t>, std::string>& data) const; | ||
|
||
std::pair<std::vector<uint8_t>, std::string> unpack( | ||
const std::vector<uint8_t>& packed_data) const; | ||
|
||
std::string vectorToString(const std::vector<uint8_t>& vec) const; | ||
|
||
std::pair<std::vector<std::vector<Record>>, std::string> TSetSetup( | ||
const std::unordered_map< | ||
std::string, | ||
std::vector<std::pair<std::vector<uint8_t>, std::string>>>& T, | ||
const std::vector<std::string>& keywords); | ||
|
||
std::vector<uint8_t> TSetGetTag(const std::string& Kt, | ||
const std::string& w) const; | ||
|
||
std::vector<std::pair<std::vector<uint8_t>, std::string>> TSetRetrieve( | ||
const std::vector<std::vector<Record>>& TSet, | ||
const std::string& stag) const; | ||
|
||
// 公共接口函数,用于访问私有成员变量 | ||
const std::vector<std::vector<Record>>& getTSet() const { return TSet_; } | ||
const std::vector<std::set<int>>& getFree() const { return Free_; } | ||
|
||
private: | ||
void initialize(); | ||
|
||
int B_; | ||
int S_; | ||
int lambda_; | ||
int n_lambda_; | ||
std::vector<std::vector<Record>> TSet_; | ||
std::vector<std::set<int>> Free_; | ||
}; | ||
|
||
} // namespace yacl::crypto::primitives::sse |
Oops, something went wrong.