Skip to content

Commit 5afe9bd

Browse files
authored
Fix WAVM support. (#47)
Signed-off-by: John Plevyak <[email protected]>
1 parent d1e7040 commit 5afe9bd

File tree

2 files changed

+109
-62
lines changed

2 files changed

+109
-62
lines changed

src/wasm.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,9 @@ bool WasmBase::initialize(const std::string &code, bool allow_precompiled) {
341341

342342
if (started_from_ != Cloneable::InstantiatedModule) {
343343
registerCallbacks();
344-
wasm_vm_->link(vm_id_);
344+
if (!wasm_vm_->link(vm_id_)) {
345+
return false;
346+
}
345347
}
346348

347349
vm_context_.reset(createVmContext());

src/wavm/wavm.cc

Lines changed: 106 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
#include "include/proxy-wasm/wavm.h"
1717

18+
#include <cstdlib>
1819
#include <iostream>
20+
#include <map>
1921
#include <memory>
2022
#include <optional>
2123
#include <unordered_map>
@@ -43,8 +45,16 @@
4345
#include "WAVM/Runtime/Runtime.h"
4446
#include "WAVM/WASM/WASM.h"
4547
#include "WAVM/WASTParse/WASTParse.h"
46-
#include "absl/container/node_hash_map.h"
47-
#include "absl/strings/match.h"
48+
49+
#ifdef NDEBUG
50+
#define ASSERT(_x) _x
51+
#else
52+
#define ASSERT(_x) \
53+
do { \
54+
if (!_x) \
55+
::exit(1); \
56+
} while (0)
57+
#endif
4858

4959
using namespace WAVM;
5060
using namespace WAVM::IR;
@@ -74,15 +84,21 @@ struct Wavm;
7484

7585
namespace {
7686

77-
#define CALL_WITH_CONTEXT(_x, _context) \
87+
#define CALL_WITH_CONTEXT(_x, _context, _wavm) \
7888
do { \
79-
SaveRestoreContext _saved_context(static_cast<ContextBase *>(_context)); \
80-
WAVM::Runtime::catchRuntimeExceptions([&] { _x; }, \
81-
[&](WAVM::Runtime::Exception *exception) { \
82-
auto description = describeException(exception); \
83-
destroyException(exception); \
84-
throw WasmException(description); \
85-
}); \
89+
try { \
90+
SaveRestoreContext _saved_context(static_cast<ContextBase *>(_context)); \
91+
WAVM::Runtime::catchRuntimeExceptions( \
92+
[&] { _x; }, \
93+
[&](WAVM::Runtime::Exception *exception) { \
94+
auto description = describeException(exception); \
95+
_wavm->fail(FailState::RuntimeError, \
96+
"Function: " + std::string(function_name) + " failed: " + description); \
97+
destroyException(exception); \
98+
throw std::exception(); \
99+
}); \
100+
} catch (...) { \
101+
} \
86102
} while (0)
87103

88104
struct WasmUntaggedValue : public WAVM::IR::UntaggedValue {
@@ -96,11 +112,9 @@ struct WasmUntaggedValue : public WAVM::IR::UntaggedValue {
96112
WasmUntaggedValue(F64 inF64) { f64 = inF64; }
97113
};
98114

99-
const Logger::Id wasmId = Logger::Id::wasm;
100-
101-
class RootResolver : public WAVM::Runtime::Resolver, public Logger::Loggable<wasmId> {
115+
class RootResolver : public WAVM::Runtime::Resolver {
102116
public:
103-
RootResolver(WAVM::Runtime::Compartment *, WavmVm *vm) : vm_(vm) {}
117+
RootResolver(WAVM::Runtime::Compartment *, WasmVm *vm) : vm_(vm) {}
104118

105119
virtual ~RootResolver() { module_name_to_instance_map_.clear(); }
106120

@@ -113,10 +127,12 @@ class RootResolver : public WAVM::Runtime::Resolver, public Logger::Loggable<was
113127
if (isA(out_object, type)) {
114128
return true;
115129
} else {
116-
vm_->error("Failed to load WASM module due to a type mismatch in an import: " +
117-
std::string(module_name) + "." + export_name + " " +
118-
asString(WAVM::Runtime::getExternType(out_object)) +
119-
" but was expecting type: " + asString(type));
130+
vm_->fail(FailState::UnableToInitializeCode,
131+
"Failed to load WASM module due to a type mismatch in an import: " +
132+
std::string(module_name) + "." + export_name + " " +
133+
asString(WAVM::Runtime::getExternType(out_object)) +
134+
" but was expecting type: " + asString(type));
135+
return false;
120136
}
121137
}
122138
}
@@ -125,19 +141,21 @@ class RootResolver : public WAVM::Runtime::Resolver, public Logger::Loggable<was
125141
return true;
126142
}
127143
}
128-
vm_->error("Failed to load Wasm module due to a missing import: " + std::string(module_name) +
129-
"." + std::string(export_name) + " " + asString(type));
144+
vm_->fail(FailState::MissingFunction,
145+
"Failed to load Wasm module due to a missing import: " + std::string(module_name) +
146+
"." + std::string(export_name) + " " + asString(type));
147+
return false;
130148
}
131149

132-
HashMap<std::string, WAVM::Runtime::ModuleInstance *> &moduleNameToInstanceMap() {
150+
HashMap<std::string, WAVM::Runtime::Instance *> &moduleNameToInstanceMap() {
133151
return module_name_to_instance_map_;
134152
}
135153

136154
void addResolver(WAVM::Runtime::Resolver *r) { resolvers_.push_back(r); }
137155

138156
private:
139-
WavmVm *vm_;
140-
HashMap<std::string, WAVM::Runtime::ModuleInstance *> module_name_to_instance_map_;
157+
WasmVm *vm_;
158+
HashMap<std::string, WAVM::Runtime::Instance *> module_name_to_instance_map_;
141159
std::vector<WAVM::Runtime::Resolver *> resolvers_;
142160
};
143161

@@ -173,23 +191,24 @@ struct PairHash {
173191
}
174192
};
175193

176-
struct Wavm : public WasmVmBase {
177-
Wavm(Stats::ScopeSharedPtr scope) : WasmVmBase(scope, WasmRuntimeNames::get().Wavm) {}
194+
struct Wavm : public WasmVm {
195+
Wavm() : WasmVm() {}
178196
~Wavm() override;
179197

180198
// WasmVm
181-
std::string_view runtime() override { return WasmRuntimeNames::get().Wavm; }
199+
std::string_view runtime() override { return "wavm"; }
182200
Cloneable cloneable() override { return Cloneable::InstantiatedModule; };
183201
std::unique_ptr<WasmVm> clone() override;
184202
bool load(const std::string &code, bool allow_precompiled) override;
185-
void link(std::string_view debug_name) override;
203+
bool link(std::string_view debug_name) override;
186204
uint64_t getMemorySize() override;
187205
std::optional<std::string_view> getMemory(uint64_t pointer, uint64_t size) override;
188206
bool setMemory(uint64_t pointer, uint64_t size, const void *data) override;
189207
bool getWord(uint64_t pointer, Word *data) override;
190208
bool setWord(uint64_t pointer, Word data) override;
191209
std::string_view getCustomSection(std::string_view name) override;
192210
std::string_view getPrecompiledSectionName() override;
211+
AbiVersion getAbiVersion() override;
193212

194213
#define _GET_FUNCTION(_T) \
195214
void getFunction(std::string_view function_name, _T *f) override { \
@@ -209,15 +228,16 @@ struct Wavm : public WasmVmBase {
209228
bool has_instantiated_module_ = false;
210229
IR::Module ir_module_;
211230
WAVM::Runtime::ModuleRef module_ = nullptr;
212-
WAVM::Runtime::GCPointer<WAVM::Runtime::ModuleInstance> module_instance_;
231+
WAVM::Runtime::GCPointer<WAVM::Runtime::Instance> module_instance_;
213232
WAVM::Runtime::Memory *memory_;
214233
WAVM::Runtime::GCPointer<WAVM::Runtime::Compartment> compartment_;
215234
WAVM::Runtime::GCPointer<WAVM::Runtime::Context> context_;
216-
node_hash_map<std::string, Intrinsics::Module> intrinsic_modules_;
217-
node_hash_map<std::string, WAVM::Runtime::GCPointer<WAVM::Runtime::ModuleInstance>>
235+
std::map<std::string, Intrinsics::Module> intrinsic_modules_;
236+
std::map<std::string, WAVM::Runtime::GCPointer<WAVM::Runtime::Instance>>
218237
intrinsic_module_instances_;
219238
std::vector<std::unique_ptr<Intrinsics::Function>> envoyFunctions_;
220239
uint8_t *memory_base_ = nullptr;
240+
AbiVersion abi_version_ = AbiVersion::Unknown;
221241
};
222242

223243
Wavm::~Wavm() {
@@ -232,11 +252,12 @@ Wavm::~Wavm() {
232252
}
233253

234254
std::unique_ptr<WasmVm> Wavm::clone() {
235-
auto wavm = std::make_unique<Wavm>(scope_);
255+
auto wavm = std::make_unique<Wavm>();
236256
wavm->compartment_ = WAVM::Runtime::cloneCompartment(compartment_);
237257
wavm->memory_ = WAVM::Runtime::remapToClonedCompartment(memory_, wavm->compartment_);
238258
wavm->memory_base_ = WAVM::Runtime::getMemoryBaseAddress(wavm->memory_);
239259
wavm->context_ = WAVM::Runtime::createContext(wavm->compartment_);
260+
wavm->abi_version_ = abi_version_;
240261
for (auto &p : intrinsic_module_instances_) {
241262
wavm->intrinsic_module_instances_.emplace(
242263
p.first, WAVM::Runtime::remapToClonedCompartment(p.second, wavm->compartment_));
@@ -254,7 +275,7 @@ bool Wavm::load(const std::string &code, bool allow_precompiled) {
254275
if (!loadModule(code, ir_module_)) {
255276
return false;
256277
}
257-
// todo check percompiled section is permitted
278+
getAbiVersion(); // Cache ABI version.
258279
const CustomSection *precompiled_object_section = nullptr;
259280
if (allow_precompiled) {
260281
for (const CustomSection &customSection : ir_module_.customSections) {
@@ -272,21 +293,48 @@ bool Wavm::load(const std::string &code, bool allow_precompiled) {
272293
return true;
273294
}
274295

275-
AbiVersion Wavm::getAbiVersion() { return AbiVersion::Unknown; }
296+
AbiVersion Wavm::getAbiVersion() {
297+
if (abi_version_ != AbiVersion::Unknown) {
298+
return abi_version_;
299+
}
300+
for (auto &e : ir_module_.exports) {
301+
if (e.name == "proxy_abi_version_0_1_0") {
302+
abi_version_ = AbiVersion::ProxyWasm_0_1_0;
303+
return abi_version_;
304+
}
305+
if (e.name == "proxy_abi_version_0_2_0") {
306+
abi_version_ = AbiVersion::ProxyWasm_0_2_0;
307+
return abi_version_;
308+
}
309+
if (e.name == "proxy_abi_version_0_2_1") {
310+
abi_version_ = AbiVersion::ProxyWasm_0_2_1;
311+
return abi_version_;
312+
}
313+
}
314+
return AbiVersion::Unknown;
315+
}
276316

277-
void Wavm::link(std::string_view debug_name) {
278-
RootResolver rootResolver(compartment_);
317+
bool Wavm::link(std::string_view debug_name) {
318+
RootResolver rootResolver(compartment_, this);
279319
for (auto &p : intrinsic_modules_) {
280320
auto instance = Intrinsics::instantiateModule(compartment_, {&intrinsic_modules_[p.first]},
281321
std::string(p.first));
282322
intrinsic_module_instances_.emplace(p.first, instance);
283323
rootResolver.moduleNameToInstanceMap().set(p.first, instance);
284324
}
285325
WAVM::Runtime::LinkResult link_result = linkModule(ir_module_, rootResolver);
326+
if (!link_result.missingImports.empty()) {
327+
for (auto &i : link_result.missingImports) {
328+
error("Missing Wasm import " + i.moduleName + " " + i.exportName);
329+
}
330+
fail(FailState::MissingFunction, "Failed to load Wasm module due to a missing import(s)");
331+
return false;
332+
}
286333
module_instance_ = instantiateModule(
287334
compartment_, module_, std::move(link_result.resolvedImports), std::string(debug_name));
288335
memory_ = getDefaultMemory(module_instance_);
289336
memory_base_ = WAVM::Runtime::getMemoryBaseAddress(memory_);
337+
return true;
290338
}
291339

292340
uint64_t Wavm::getMemorySize() { return WAVM::Runtime::getMemoryNumPages(memory_) * WasmPageSize; }
@@ -326,7 +374,7 @@ bool Wavm::setWord(uint64_t pointer, Word data) {
326374
return setMemory(pointer, sizeof(uint32_t), &data32);
327375
}
328376

329-
std::string_view Wavm::getCustomSection(string_view name) {
377+
std::string_view Wavm::getCustomSection(std::string_view name) {
330378
for (auto &section : ir_module_.customSections) {
331379
if (section.name == name) {
332380
return {reinterpret_cast<char *>(section.data.data()), section.data.size()};
@@ -337,12 +385,10 @@ std::string_view Wavm::getCustomSection(string_view name) {
337385

338386
std::string_view Wavm::getPrecompiledSectionName() { return "wavm.precompiled_object"; }
339387

340-
std::unique_ptr<WasmVm> createVm(Stats::ScopeSharedPtr scope) {
341-
return std::make_unique<Wavm>(scope);
342-
}
343-
344388
} // namespace Wavm
345389

390+
std::unique_ptr<WasmVm> createWavmVm() { return std::make_unique<proxy_wasm::Wavm::Wavm>(); }
391+
346392
template <typename R, typename... Args>
347393
IR::FunctionType inferEnvoyFunctionType(R (*)(void *, Args...)) {
348394
return IR::FunctionType(IR::inferResultType<R>(), IR::TypeTuple({IR::inferValueType<Args>()...}),
@@ -354,10 +400,10 @@ using namespace Wavm;
354400
template <typename R, typename... Args>
355401
void registerCallbackWavm(WasmVm *vm, std::string_view module_name, std::string_view function_name,
356402
R (*f)(Args...)) {
357-
auto wavm = static_cast<Wavm *>(vm);
358-
wavm->envoyFunctions_.emplace_back(
359-
new Intrinsics::Function(&wavm->intrinsic_modules_[module_name], function_name.data(),
360-
reinterpret_cast<void *>(f), inferEnvoyFunctionType(f)));
403+
auto wavm = static_cast<proxy_wasm::Wavm::Wavm *>(vm);
404+
wavm->envoyFunctions_.emplace_back(new Intrinsics::Function(
405+
&wavm->intrinsic_modules_[std::string(module_name)], function_name.data(),
406+
reinterpret_cast<void *>(f), inferEnvoyFunctionType(f)));
361407
}
362408

363409
template void registerCallbackWavm<void, void *>(WasmVm *vm, std::string_view module_name,
@@ -452,7 +498,7 @@ static bool checkFunctionType(WAVM::Runtime::Function *f, IR::FunctionType t) {
452498
template <typename R, typename... Args>
453499
void getFunctionWavmReturn(WasmVm *vm, std::string_view function_name,
454500
std::function<R(ContextBase *, Args...)> *function, uint32_t) {
455-
auto wavm = static_cast<proxy_wasm::Wavm *>(vm);
501+
auto wavm = static_cast<proxy_wasm::Wavm::Wavm *>(vm);
456502
auto f =
457503
asFunctionNullable(getInstanceExport(wavm->module_instance_, std::string(function_name)));
458504
if (!f)
@@ -462,18 +508,19 @@ void getFunctionWavmReturn(WasmVm *vm, std::string_view function_name,
462508
return;
463509
}
464510
if (!checkFunctionType(f, inferStdFunctionType(function))) {
465-
error("Bad function signature for: " + std::string(function_name));
511+
wavm->fail(FailState::UnableToInitializeCode,
512+
"Bad function signature for: " + std::string(function_name));
466513
}
467-
*function = [wavm, f, function_name, this](ContextBase *context, Args... args) -> R {
514+
*function = [wavm, f, function_name](ContextBase *context, Args... args) -> R {
468515
WasmUntaggedValue values[] = {args...};
469516
WasmUntaggedValue return_value;
470-
try {
471-
CALL_WITH_CONTEXT(
472-
invokeFunction(wavm->context_, f, getFunctionType(f), &values[0], &return_value),
473-
context);
517+
CALL_WITH_CONTEXT(
518+
invokeFunction(wavm->context_, f, getFunctionType(f), &values[0], &return_value), context,
519+
wavm);
520+
if (!wavm->isFailed()) {
474521
return static_cast<uint32_t>(return_value.i32);
475-
} catch (const std::exception &e) {
476-
error("Function: " + std::string(function_name) + " failed: " + e.what());
522+
} else {
523+
return 0;
477524
}
478525
};
479526
}
@@ -483,7 +530,7 @@ struct Void {};
483530
template <typename R, typename... Args>
484531
void getFunctionWavmReturn(WasmVm *vm, std::string_view function_name,
485532
std::function<R(ContextBase *, Args...)> *function, Void) {
486-
auto wavm = static_cast<proxy_wasm::Wavm *>(vm);
533+
auto wavm = static_cast<proxy_wasm::Wavm::Wavm *>(vm);
487534
auto f =
488535
asFunctionNullable(getInstanceExport(wavm->module_instance_, std::string(function_name)));
489536
if (!f)
@@ -493,15 +540,13 @@ void getFunctionWavmReturn(WasmVm *vm, std::string_view function_name,
493540
return;
494541
}
495542
if (!checkFunctionType(f, inferStdFunctionType(function))) {
496-
vm->error("Bad function signature for: " + std::string(function_name));
543+
wavm->fail(FailState::UnableToInitializeCode,
544+
"Bad function signature for: " + std::string(function_name));
497545
}
498-
*function = [wavm, f, function_name, this](ContextBase *context, Args... args) -> R {
546+
*function = [wavm, f, function_name](ContextBase *context, Args... args) -> R {
499547
WasmUntaggedValue values[] = {args...};
500-
try {
501-
CALL_WITH_CONTEXT(invokeFunction(wavm->context_, f, getFunctionType(f), &values[0]), context);
502-
} catch (const std::exception &e) {
503-
error("Function: " + std::string(function_name) + " failed: " + e.what());
504-
}
548+
CALL_WITH_CONTEXT(invokeFunction(wavm->context_, f, getFunctionType(f), &values[0]), context,
549+
wavm);
505550
};
506551
}
507552

0 commit comments

Comments
 (0)