diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h index f88b4aaac..7f66cdbfa 100644 --- a/include/proxy-wasm/context.h +++ b/include/proxy-wasm/context.h @@ -353,6 +353,9 @@ class ContextBase : public RootInterface, WasmResult getSharedData(std::string_view key, std::pair *data) override; WasmResult setSharedData(std::string_view key, std::string_view value, uint32_t cas) override; + WasmResult getSharedDataKeys(std::vector *result) override; + WasmResult removeSharedDataKey(std::string_view key, uint32_t cas, + std::pair *result) override; // Shared Queue WasmResult registerSharedQueue(std::string_view queue_name, diff --git a/include/proxy-wasm/context_interface.h b/include/proxy-wasm/context_interface.h index 841dd0e16..85e251a0f 100644 --- a/include/proxy-wasm/context_interface.h +++ b/include/proxy-wasm/context_interface.h @@ -596,7 +596,7 @@ struct GeneralInterface { }; /** - * SharedDataInterface is for shaing data between VMs. In general the VMs may be on different + * SharedDataInterface is for sharing data between VMs. In general the VMs may be on different * threads. Keys can have any format, but good practice would use reverse DNS and namespacing * prefixes to avoid conflicts. */ @@ -621,6 +621,23 @@ struct SharedDataInterface { * @param data is a location to store the returned value. */ virtual WasmResult setSharedData(std::string_view key, std::string_view value, uint32_t cas) = 0; + + /** + * Return all the keys from the data shraed between VMs + * @param data is a location to store the returned value. + */ + virtual WasmResult getSharedDataKeys(std::vector *result) = 0; + + /** + * Removes the given key from the data shared between VMs. + * @param key is a proxy-wide key mapping to the shared data value. + * @param cas is a compare-and-swap value. If it is zero it is ignored, otherwise it must match + * @param cas is a location to store value, and cas number, associated with the removed key + * the cas associated with the value. + */ + virtual WasmResult + removeSharedDataKey(std::string_view key, uint32_t cas, + std::pair *result) = 0; }; // namespace proxy_wasm struct SharedQueueInterface { diff --git a/src/context.cc b/src/context.cc index 8a69aec5f..81633a7b3 100644 --- a/src/context.cc +++ b/src/context.cc @@ -192,6 +192,15 @@ WasmResult ContextBase::setSharedData(std::string_view key, std::string_view val return getGlobalSharedData().set(wasm_->vm_id(), key, value, cas); } +WasmResult ContextBase::getSharedDataKeys(std::vector *result) { + return getGlobalSharedData().keys(wasm_->vm_id(), result); +} + +WasmResult ContextBase::removeSharedDataKey(std::string_view key, uint32_t cas, + std::pair *result) { + return getGlobalSharedData().remove(wasm_->vm_id(), key, cas, result); +} + // Shared Queue WasmResult ContextBase::registerSharedQueue(std::string_view queue_name, diff --git a/src/shared_data.cc b/src/shared_data.cc index 73cbb1feb..d4306adae 100644 --- a/src/shared_data.cc +++ b/src/shared_data.cc @@ -56,6 +56,22 @@ WasmResult SharedData::get(std::string_view vm_id, const std::string_view key, return WasmResult::NotFound; } +WasmResult SharedData::keys(std::string_view vm_id, std::vector *result) { + result->clear(); + + std::lock_guard lock(mutex_); + auto map = data_.find(std::string(vm_id)); + if (map == data_.end()) { + return WasmResult::Ok; + } + + for (auto kv : map->second) { + result->push_back(kv.first); + } + + return WasmResult::Ok; +} + WasmResult SharedData::set(std::string_view vm_id, std::string_view key, std::string_view value, uint32_t cas) { std::lock_guard lock(mutex_); @@ -78,4 +94,29 @@ WasmResult SharedData::set(std::string_view vm_id, std::string_view key, std::st return WasmResult::Ok; } +WasmResult SharedData::remove(std::string_view vm_id, std::string_view key, uint32_t cas, + std::pair *result) { + std::lock_guard lock(mutex_); + std::unordered_map> *map; + auto map_it = data_.find(std::string(vm_id)); + if (map_it == data_.end()) { + return WasmResult::NotFound; + } else { + map = &map_it->second; + } + + auto it = map->find(std::string(key)); + if (it != map->end()) { + if (cas && cas != it->second.second) { + return WasmResult::CasMismatch; + } + if (result != nullptr) { + *result = it->second; + } + map->erase(it); + return WasmResult::Ok; + } + return WasmResult::NotFound; +} + } // namespace proxy_wasm diff --git a/src/shared_data.h b/src/shared_data.h index fec37aac8..cbc76fb12 100644 --- a/src/shared_data.h +++ b/src/shared_data.h @@ -25,8 +25,11 @@ class SharedData { SharedData(bool register_vm_id_callback = true); WasmResult get(std::string_view vm_id, const std::string_view key, std::pair *result); + WasmResult keys(std::string_view vm_id, std::vector *result); WasmResult set(std::string_view vm_id, std::string_view key, std::string_view value, uint32_t cas); + WasmResult remove(std::string_view vm_id, const std::string_view key, uint32_t cas, + std::pair *result); void deleteByVmId(std::string_view vm_id); private: diff --git a/test/shared_data_test.cc b/test/shared_data_test.cc index 0c3d2dec8..062625f8d 100644 --- a/test/shared_data_test.cc +++ b/test/shared_data_test.cc @@ -24,10 +24,24 @@ namespace proxy_wasm { TEST(SharedData, SingleThread) { SharedData shared_data(false); + std::string_view vm_id = "id"; + + // Validate we get an 'Ok' response when fetching keys before anything + // is initialized. + std::vector keys; + EXPECT_EQ(WasmResult::Ok, shared_data.keys(vm_id, &keys)); + EXPECT_EQ(0, keys.size()); + + // Validate that we clear the result set + std::vector nonEmptyKeys(2); + nonEmptyKeys[0] = "valueA"; + nonEmptyKeys[1] = "valueB"; + EXPECT_EQ(WasmResult::Ok, shared_data.keys(vm_id, &nonEmptyKeys)); + EXPECT_EQ(0, nonEmptyKeys.size()); + std::pair result; EXPECT_EQ(WasmResult::NotFound, shared_data.get("non-exist", "non-exists", &result)); - std::string_view vm_id = "id"; std::string_view key = "key"; std::string_view value = "1"; EXPECT_EQ(WasmResult::Ok, shared_data.set(vm_id, key, value, 0)); @@ -44,6 +58,31 @@ TEST(SharedData, SingleThread) { EXPECT_EQ(WasmResult::Ok, shared_data.get(vm_id, key, &result)); EXPECT_EQ(value, result.first); EXPECT_EQ(result.second, 3); + + EXPECT_EQ(WasmResult::Ok, shared_data.keys(vm_id, &keys)); + EXPECT_EQ(1, keys.size()); + EXPECT_EQ(key, keys[0]); + + keys.clear(); + EXPECT_EQ(WasmResult::CasMismatch, shared_data.remove(vm_id, key, 911, nullptr)); + EXPECT_EQ(WasmResult::Ok, shared_data.keys(vm_id, &keys)); + EXPECT_EQ(1, keys.size()); + + EXPECT_EQ(WasmResult::Ok, shared_data.remove(vm_id, key, 0, nullptr)); + EXPECT_EQ(WasmResult::NotFound, shared_data.get(vm_id, key, &result)); + + EXPECT_EQ(WasmResult::NotFound, shared_data.remove(vm_id, "non-existent_key", 0, nullptr)); + + EXPECT_EQ(WasmResult::Ok, shared_data.set(vm_id, key, value, 0)); + EXPECT_EQ(WasmResult::Ok, shared_data.set(vm_id, key, value, 0)); + EXPECT_EQ(WasmResult::Ok, shared_data.get(vm_id, key, &result)); + + uint32_t expectedCasValue = result.second; + + std::pair removeResult; + EXPECT_EQ(WasmResult::Ok, shared_data.remove(vm_id, key, 0, &removeResult)); + EXPECT_EQ(value, removeResult.first); + EXPECT_EQ(removeResult.second, expectedCasValue); } void incrementData(SharedData *shared_data, std::string_view vm_id, std::string_view key) {