diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h
index f88b4aaac..7f66cdbfa 100644
--- a/include/proxy-wasm/context.h
+++ b/include/proxy-wasm/context.h
@@ -353,6 +353,9 @@ class ContextBase : public RootInterface,
WasmResult getSharedData(std::string_view key,
std::pair *data) override;
WasmResult setSharedData(std::string_view key, std::string_view value, uint32_t cas) override;
+ WasmResult getSharedDataKeys(std::vector *result) override;
+ WasmResult removeSharedDataKey(std::string_view key, uint32_t cas,
+ std::pair *result) override;
// Shared Queue
WasmResult registerSharedQueue(std::string_view queue_name,
diff --git a/include/proxy-wasm/context_interface.h b/include/proxy-wasm/context_interface.h
index 841dd0e16..85e251a0f 100644
--- a/include/proxy-wasm/context_interface.h
+++ b/include/proxy-wasm/context_interface.h
@@ -596,7 +596,7 @@ struct GeneralInterface {
};
/**
- * SharedDataInterface is for shaing data between VMs. In general the VMs may be on different
+ * SharedDataInterface is for sharing data between VMs. In general the VMs may be on different
* threads. Keys can have any format, but good practice would use reverse DNS and namespacing
* prefixes to avoid conflicts.
*/
@@ -621,6 +621,23 @@ struct SharedDataInterface {
* @param data is a location to store the returned value.
*/
virtual WasmResult setSharedData(std::string_view key, std::string_view value, uint32_t cas) = 0;
+
+ /**
+ * Return all the keys from the data shraed between VMs
+ * @param data is a location to store the returned value.
+ */
+ virtual WasmResult getSharedDataKeys(std::vector *result) = 0;
+
+ /**
+ * Removes the given key from the data shared between VMs.
+ * @param key is a proxy-wide key mapping to the shared data value.
+ * @param cas is a compare-and-swap value. If it is zero it is ignored, otherwise it must match
+ * @param cas is a location to store value, and cas number, associated with the removed key
+ * the cas associated with the value.
+ */
+ virtual WasmResult
+ removeSharedDataKey(std::string_view key, uint32_t cas,
+ std::pair *result) = 0;
}; // namespace proxy_wasm
struct SharedQueueInterface {
diff --git a/src/context.cc b/src/context.cc
index 8a69aec5f..81633a7b3 100644
--- a/src/context.cc
+++ b/src/context.cc
@@ -192,6 +192,15 @@ WasmResult ContextBase::setSharedData(std::string_view key, std::string_view val
return getGlobalSharedData().set(wasm_->vm_id(), key, value, cas);
}
+WasmResult ContextBase::getSharedDataKeys(std::vector *result) {
+ return getGlobalSharedData().keys(wasm_->vm_id(), result);
+}
+
+WasmResult ContextBase::removeSharedDataKey(std::string_view key, uint32_t cas,
+ std::pair *result) {
+ return getGlobalSharedData().remove(wasm_->vm_id(), key, cas, result);
+}
+
// Shared Queue
WasmResult ContextBase::registerSharedQueue(std::string_view queue_name,
diff --git a/src/shared_data.cc b/src/shared_data.cc
index 73cbb1feb..d4306adae 100644
--- a/src/shared_data.cc
+++ b/src/shared_data.cc
@@ -56,6 +56,22 @@ WasmResult SharedData::get(std::string_view vm_id, const std::string_view key,
return WasmResult::NotFound;
}
+WasmResult SharedData::keys(std::string_view vm_id, std::vector *result) {
+ result->clear();
+
+ std::lock_guard lock(mutex_);
+ auto map = data_.find(std::string(vm_id));
+ if (map == data_.end()) {
+ return WasmResult::Ok;
+ }
+
+ for (auto kv : map->second) {
+ result->push_back(kv.first);
+ }
+
+ return WasmResult::Ok;
+}
+
WasmResult SharedData::set(std::string_view vm_id, std::string_view key, std::string_view value,
uint32_t cas) {
std::lock_guard lock(mutex_);
@@ -78,4 +94,29 @@ WasmResult SharedData::set(std::string_view vm_id, std::string_view key, std::st
return WasmResult::Ok;
}
+WasmResult SharedData::remove(std::string_view vm_id, std::string_view key, uint32_t cas,
+ std::pair *result) {
+ std::lock_guard lock(mutex_);
+ std::unordered_map> *map;
+ auto map_it = data_.find(std::string(vm_id));
+ if (map_it == data_.end()) {
+ return WasmResult::NotFound;
+ } else {
+ map = &map_it->second;
+ }
+
+ auto it = map->find(std::string(key));
+ if (it != map->end()) {
+ if (cas && cas != it->second.second) {
+ return WasmResult::CasMismatch;
+ }
+ if (result != nullptr) {
+ *result = it->second;
+ }
+ map->erase(it);
+ return WasmResult::Ok;
+ }
+ return WasmResult::NotFound;
+}
+
} // namespace proxy_wasm
diff --git a/src/shared_data.h b/src/shared_data.h
index fec37aac8..cbc76fb12 100644
--- a/src/shared_data.h
+++ b/src/shared_data.h
@@ -25,8 +25,11 @@ class SharedData {
SharedData(bool register_vm_id_callback = true);
WasmResult get(std::string_view vm_id, const std::string_view key,
std::pair *result);
+ WasmResult keys(std::string_view vm_id, std::vector *result);
WasmResult set(std::string_view vm_id, std::string_view key, std::string_view value,
uint32_t cas);
+ WasmResult remove(std::string_view vm_id, const std::string_view key, uint32_t cas,
+ std::pair *result);
void deleteByVmId(std::string_view vm_id);
private:
diff --git a/test/shared_data_test.cc b/test/shared_data_test.cc
index 0c3d2dec8..062625f8d 100644
--- a/test/shared_data_test.cc
+++ b/test/shared_data_test.cc
@@ -24,10 +24,24 @@ namespace proxy_wasm {
TEST(SharedData, SingleThread) {
SharedData shared_data(false);
+ std::string_view vm_id = "id";
+
+ // Validate we get an 'Ok' response when fetching keys before anything
+ // is initialized.
+ std::vector keys;
+ EXPECT_EQ(WasmResult::Ok, shared_data.keys(vm_id, &keys));
+ EXPECT_EQ(0, keys.size());
+
+ // Validate that we clear the result set
+ std::vector nonEmptyKeys(2);
+ nonEmptyKeys[0] = "valueA";
+ nonEmptyKeys[1] = "valueB";
+ EXPECT_EQ(WasmResult::Ok, shared_data.keys(vm_id, &nonEmptyKeys));
+ EXPECT_EQ(0, nonEmptyKeys.size());
+
std::pair result;
EXPECT_EQ(WasmResult::NotFound, shared_data.get("non-exist", "non-exists", &result));
- std::string_view vm_id = "id";
std::string_view key = "key";
std::string_view value = "1";
EXPECT_EQ(WasmResult::Ok, shared_data.set(vm_id, key, value, 0));
@@ -44,6 +58,31 @@ TEST(SharedData, SingleThread) {
EXPECT_EQ(WasmResult::Ok, shared_data.get(vm_id, key, &result));
EXPECT_EQ(value, result.first);
EXPECT_EQ(result.second, 3);
+
+ EXPECT_EQ(WasmResult::Ok, shared_data.keys(vm_id, &keys));
+ EXPECT_EQ(1, keys.size());
+ EXPECT_EQ(key, keys[0]);
+
+ keys.clear();
+ EXPECT_EQ(WasmResult::CasMismatch, shared_data.remove(vm_id, key, 911, nullptr));
+ EXPECT_EQ(WasmResult::Ok, shared_data.keys(vm_id, &keys));
+ EXPECT_EQ(1, keys.size());
+
+ EXPECT_EQ(WasmResult::Ok, shared_data.remove(vm_id, key, 0, nullptr));
+ EXPECT_EQ(WasmResult::NotFound, shared_data.get(vm_id, key, &result));
+
+ EXPECT_EQ(WasmResult::NotFound, shared_data.remove(vm_id, "non-existent_key", 0, nullptr));
+
+ EXPECT_EQ(WasmResult::Ok, shared_data.set(vm_id, key, value, 0));
+ EXPECT_EQ(WasmResult::Ok, shared_data.set(vm_id, key, value, 0));
+ EXPECT_EQ(WasmResult::Ok, shared_data.get(vm_id, key, &result));
+
+ uint32_t expectedCasValue = result.second;
+
+ std::pair removeResult;
+ EXPECT_EQ(WasmResult::Ok, shared_data.remove(vm_id, key, 0, &removeResult));
+ EXPECT_EQ(value, removeResult.first);
+ EXPECT_EQ(removeResult.second, expectedCasValue);
}
void incrementData(SharedData *shared_data, std::string_view vm_id, std::string_view key) {