Skip to content

[Strings] Add a string lowering pass using magic imports #6497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2838,8 +2838,9 @@ void PrintSExpression::handleSignature(HeapType curr, Name name) {
void PrintSExpression::visitExport(Export* curr) {
o << '(';
printMedium(o, "export ");
// TODO: Escape the string properly.
printText(o, curr->name.str.data()) << " (";
std::stringstream escaped;
String::printEscaped(escaped, curr->name.str);
printText(o, escaped.str(), false) << " (";
switch (curr->kind) {
case ExternalKind::Function:
o << "func";
Expand All @@ -2865,9 +2866,11 @@ void PrintSExpression::visitExport(Export* curr) {

void PrintSExpression::emitImportHeader(Importable* curr) {
printMedium(o, "import ");
// TODO: Escape the strings properly and use std::string_view.
printText(o, curr->module.str.data()) << ' ';
printText(o, curr->base.str.data()) << ' ';
std::stringstream escapedModule, escapedBase;
String::printEscaped(escapedModule, curr->module.str);
String::printEscaped(escapedBase, curr->base.str);
printText(o, escapedModule.str(), false) << ' ';
printText(o, escapedBase.str(), false) << ' ';
}

void PrintSExpression::visitGlobal(Global* curr) {
Expand Down
35 changes: 24 additions & 11 deletions src/passes/StringLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ struct StringGathering : public Pass {
};

struct StringLowering : public StringGathering {
// If true, then encode well-formed strings as (import "'" "string...")
// instead of emitting them into the JSON custom section.
bool useMagicImports;

StringLowering(bool useMagicImports = false)
: useMagicImports(useMagicImports) {}

void run(Module* module) override {
if (!module->features.has(FeatureSet::Strings)) {
return;
Expand Down Expand Up @@ -217,25 +224,30 @@ struct StringLowering : public StringGathering {
}

void makeImports(Module* module) {
Index importIndex = 0;
Index jsonImportIndex = 0;
std::stringstream json;
json << '[';
bool first = true;
std::vector<Name> importedStrings;
for (auto& global : module->globals) {
if (global->init) {
if (auto* c = global->init->dynCast<StringConst>()) {
global->module = "string.const";
global->base = std::to_string(importIndex);
importIndex++;
global->init = nullptr;

if (first) {
first = false;
std::stringstream utf8;
if (useMagicImports &&
String::convertUTF16ToUTF8(utf8, c->string.str)) {
global->module = "'";
global->base = Name(utf8.str());
} else {
json << ',';
global->module = "string.const";
global->base = std::to_string(jsonImportIndex);
if (first) {
first = false;
} else {
json << ',';
}
String::printEscapedJSON(json, c->string.str);
jsonImportIndex++;
}
String::printEscapedJSON(json, c->string.str);
global->init = nullptr;
}
}
}
Expand Down Expand Up @@ -516,5 +528,6 @@ struct StringLowering : public StringGathering {

// Factory functions handing freshly constructed pass objects to the caller.

// Creates the string-gathering pass.
Pass* createStringGatheringPass() { return new StringGathering(); }
// Creates the string-lowering pass with the constructor's default setting
// (`useMagicImports` is false).
Pass* createStringLoweringPass() { return new StringLowering(); }
// Creates the string-lowering pass with `useMagicImports` set to true.
Pass* createStringLoweringMagicImportPass() { return new StringLowering(true); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Pass* createStringLoweringMagicImportPass() { return new StringLowering(true); }
Pass* createMagicStringLoweringPass() { return new StringLowering(true); }

/jk

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🪄 ✨


} // namespace wasm
4 changes: 4 additions & 0 deletions src/passes/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,10 @@ void PassRegistry::registerPasses() {
registerPass("string-lowering",
"lowers wasm strings and operations to imports",
createStringLoweringPass);
registerPass(
"string-lowering-magic-imports",
"same as string-lowering, but encodes well-formed strings as magic imports",
createStringLoweringMagicImportPass);
registerPass(
"strip", "deprecated; same as strip-debug", createStripDebugPass);
registerPass("stack-check",
Expand Down
1 change: 1 addition & 0 deletions src/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ Pass* createSimplifyLocalsNoTeeNoStructurePass();
Pass* createStackCheckPass();
Pass* createStringGatheringPass();
Pass* createStringLoweringPass();
Pass* createStringLoweringMagicImportPass();
Pass* createStripDebugPass();
Pass* createStripDWARFPass();
Pass* createStripProducersPass();
Expand Down
18 changes: 12 additions & 6 deletions src/pretty_printing.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,35 @@ inline std::ostream& restoreNormalColor(std::ostream& o) {
return o;
}

inline std::ostream& printText(std::ostream& o, const char* str) {
o << '"';
inline std::ostream&
printText(std::ostream& o, std::string_view str, bool needQuotes = true) {
if (needQuotes) {
o << '"';
}
Colors::green(o);
o << str;
Colors::normal(o);
return o << '"';
if (needQuotes) {
o << '"';
}
return o;
}

inline std::ostream& printMajor(std::ostream& o, const char* str) {
inline std::ostream& printMajor(std::ostream& o, std::string_view str) {
prepareMajorColor(o);
o << str;
restoreNormalColor(o);
return o;
}

inline std::ostream& printMedium(std::ostream& o, const char* str) {
inline std::ostream& printMedium(std::ostream& o, std::string_view str) {
prepareColor(o);
o << str;
restoreNormalColor(o);
return o;
}

inline std::ostream& printMinor(std::ostream& o, const char* str) {
inline std::ostream& printMinor(std::ostream& o, std::string_view str) {
prepareMinorColor(o);
o << str;
restoreNormalColor(o);
Expand Down
41 changes: 29 additions & 12 deletions src/support/string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ std::optional<uint16_t> takeWTF16CodeUnit(std::string_view& str) {
return u;
}

std::optional<uint32_t> takeWTF16CodePoint(std::string_view& str) {
std::optional<uint32_t> takeWTF16CodePoint(std::string_view& str,
bool allowWTF = true) {
auto u = takeWTF16CodeUnit(str);
if (!u) {
return std::nullopt;
Expand All @@ -228,7 +229,13 @@ std::optional<uint32_t> takeWTF16CodePoint(std::string_view& str) {
uint16_t highBits = *u - 0xD800;
uint16_t lowBits = *low - 0xDC00;
return 0x10000 + ((highBits << 10) | lowBits);
} else if (!allowWTF) {
// Unpaired high surrogate.
return std::nullopt;
}
} else if (!allowWTF && 0xDC00 <= *u && *u < 0xE000) {
// Unpaired low surrogate.
return std::nullopt;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these bugfixes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is newly necessary to catch strings that are valid WTF-16 but not valid UTF-16.

}

return *u;
Expand All @@ -242,6 +249,23 @@ void writeWTF16CodeUnit(std::ostream& os, uint16_t u) {

constexpr uint32_t replacementCharacter = 0xFFFD;

// Shared driver for the 16-bit -> 8-bit string conversions. Consumes `str`
// one code point at a time via takeWTF16CodePoint (forwarding `allowWTF`) and
// writes each one to `os` with writeWTF8CodePoint. Whenever a code point cannot
// be taken, the replacement character is written in its place.
//
// Returns `true` iff every code point was taken successfully.
bool doConvertWTF16ToWTF8(std::ostream& os,
                          std::string_view str,
                          bool allowWTF) {
  bool sawInvalid = false;

  while (!str.empty()) {
    auto codePoint = takeWTF16CodePoint(str, allowWTF);
    if (codePoint) {
      writeWTF8CodePoint(os, *codePoint);
    } else {
      // Keep going after a failure so the rest of the input is still emitted.
      sawInvalid = true;
      writeWTF8CodePoint(os, replacementCharacter);
    }
  }

  return !sawInvalid;
}

} // anonymous namespace

std::ostream& writeWTF8CodePoint(std::ostream& os, uint32_t u) {
Expand Down Expand Up @@ -308,18 +332,11 @@ bool convertWTF8ToWTF16(std::ostream& os, std::string_view str) {
}

bool convertWTF16ToWTF8(std::ostream& os, std::string_view str) {
bool valid = true;

while (str.size()) {
auto u = takeWTF16CodePoint(str);
if (!u) {
valid = false;
u = replacementCharacter;
}
writeWTF8CodePoint(os, *u);
}
return doConvertWTF16ToWTF8(os, str, true);
}

return valid;
// Strict variant of the 16-bit -> 8-bit conversion: delegates to
// doConvertWTF16ToWTF8 with `allowWTF` disabled and returns its result.
bool convertUTF16ToUTF8(std::ostream& os, std::string_view str) {
  constexpr bool allowWTF = false;
  return doConvertWTF16ToWTF8(os, str, allowWTF);
}

std::ostream& printEscapedJSON(std::ostream& os, std::string_view str) {
Expand Down
5 changes: 5 additions & 0 deletions src/support/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ bool convertWTF8ToWTF16(std::ostream& os, std::string_view str);
// Returns `true` iff the input was valid WTF-16.
bool convertWTF16ToWTF8(std::ostream& os, std::string_view str);

// Writes the UTF-8 encoding of the given UTF-16LE string to `os`, inserting a
// replacement character in place of any unpaired surrogate or incomplete code
// unit. Returns `true` if the input was valid UTF-16.
bool convertUTF16ToUTF8(std::ostream& os, std::string_view str);

} // namespace wasm::String

#endif // wasm_support_string_h
19 changes: 13 additions & 6 deletions src/wasm/wasm-s-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3622,7 +3622,9 @@ void SExpressionWasmBuilder::parseInnerData(Element& s,

void SExpressionWasmBuilder::parseExport(Element& s) {
std::unique_ptr<Export> ex = std::make_unique<Export>();
ex->name = s[1]->str();
std::vector<char> nameBytes;
stringToBinary(*s[1], s[1]->str().str, nameBytes);
ex->name = std::string(nameBytes.data(), nameBytes.size());
if (s[2]->isList()) {
auto& inner = *s[2];
if (elementStartsWith(inner, FUNC)) {
Expand Down Expand Up @@ -3703,15 +3705,20 @@ void SExpressionWasmBuilder::parseImport(Element& s) {
if (!newStyle) {
kind = ExternalKind::Function;
}
auto module = s[i++]->str();
std::vector<char> moduleBytes;
stringToBinary(*s[i], s[i]->str().str, moduleBytes);
Name module = std::string(moduleBytes.data(), moduleBytes.size());
i++;

if (!s[i]->isStr()) {
throw SParseException("no name for import", s, *s[i]);
}
auto base = s[i]->str();
if (!module.size() || !base.size()) {
throw SParseException("imports must have module and base", s, *s[i]);
}

std::vector<char> baseBytes;
stringToBinary(*s[i], s[i]->str().str, baseBytes);
Name base = std::string(baseBytes.data(), baseBytes.size());
i++;

// parse internals
Element& inner = newStyle ? *s[3] : s;
Index j = newStyle ? newStyleInner : i;
Expand Down
4 changes: 4 additions & 0 deletions test/lit/help/wasm-opt.test
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,10 @@
;; CHECK-NEXT: --string-lowering lowers wasm strings and
;; CHECK-NEXT: operations to imports
;; CHECK-NEXT:
;; CHECK-NEXT: --string-lowering-magic-imports same as string-lowering, but
;; CHECK-NEXT: encodes well-formed strings as
;; CHECK-NEXT: magic imports
;; CHECK-NEXT:
;; CHECK-NEXT: --strip deprecated; same as strip-debug
;; CHECK-NEXT:
;; CHECK-NEXT: --strip-debug strip debug info (including the
Expand Down
4 changes: 4 additions & 0 deletions test/lit/help/wasm2js.test
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,10 @@
;; CHECK-NEXT: --string-lowering lowers wasm strings and
;; CHECK-NEXT: operations to imports
;; CHECK-NEXT:
;; CHECK-NEXT: --string-lowering-magic-imports same as string-lowering, but
;; CHECK-NEXT: encodes well-formed strings as
;; CHECK-NEXT: magic imports
;; CHECK-NEXT:
;; CHECK-NEXT: --strip deprecated; same as strip-debug
;; CHECK-NEXT:
;; CHECK-NEXT: --strip-debug strip debug info (including the
Expand Down
86 changes: 86 additions & 0 deletions test/lit/passes/string-lowering-imports.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited.

;; RUN: wasm-opt %s -all --string-lowering-magic-imports --remove-unused-module-elements -S -o - | filecheck %s
;; RUN: wasm-opt %s -all --string-lowering-magic-imports --remove-unused-module-elements --roundtrip -S -o - | filecheck %s --check-prefix=RTRIP

(module
;; CHECK: (type $0 (func))

;; CHECK: (import "\'" "bar" (global $string.const_bar (ref extern)))

;; CHECK: (import "\'" "foo" (global $string.const_foo (ref extern)))

;; CHECK: (import "\'" "needs\tescaping\00.\'#%- .\r\n\\08\0c\n\r\t.\ea\99\ae" (global $"string.const_needs\tescaping\00.\'#%- .\r\n\\08\0c\n\r\t.\ea\99\ae" (ref extern)))

;; CHECK: (import "string.const" "0" (global $"string.const_unpaired high surrogate \ed\a0\80 " (ref extern)))

;; CHECK: (import "string.const" "1" (global $"string.const_unpaired low surrogate \ed\bd\88 " (ref extern)))

;; CHECK: (export "consts" (func $consts))

;; CHECK: (func $consts (type $0)
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (global.get $string.const_foo)
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (global.get $string.const_bar)
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (global.get $"string.const_needs\tescaping\00.\'#%- .\r\n\\08\0c\n\r\t.\ea\99\ae")
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (global.get $"string.const_unpaired high surrogate \ed\a0\80 ")
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (global.get $"string.const_unpaired low surrogate \ed\bd\88 ")
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; RTRIP: (type $0 (func))

;; RTRIP: (import "\'" "bar" (global $gimport$0 (ref extern)))

;; RTRIP: (import "\'" "foo" (global $gimport$1 (ref extern)))

;; RTRIP: (import "\'" "needs\tescaping\00.\'#%- .\r\n\\08\0c\n\r\t.\ea\99\ae" (global $gimport$2 (ref extern)))

;; RTRIP: (import "string.const" "0" (global $gimport$3 (ref extern)))

;; RTRIP: (import "string.const" "1" (global $gimport$4 (ref extern)))

;; RTRIP: (export "consts" (func $consts))

;; RTRIP: (func $consts (type $0)
;; RTRIP-NEXT: (drop
;; RTRIP-NEXT: (global.get $gimport$1)
;; RTRIP-NEXT: )
;; RTRIP-NEXT: (drop
;; RTRIP-NEXT: (global.get $gimport$0)
;; RTRIP-NEXT: )
;; RTRIP-NEXT: (drop
;; RTRIP-NEXT: (global.get $gimport$2)
;; RTRIP-NEXT: )
;; RTRIP-NEXT: (drop
;; RTRIP-NEXT: (global.get $gimport$3)
;; RTRIP-NEXT: )
;; RTRIP-NEXT: (drop
;; RTRIP-NEXT: (global.get $gimport$4)
;; RTRIP-NEXT: )
;; RTRIP-NEXT: )
(func $consts (export "consts")
;; Test input for --string-lowering-magic-imports (see the RUN lines above).
;; Each string.const below is expected to become an imported global.

;; Plain ASCII strings: expected to be imported from the "'" module, with the
;; string itself as the import's base name.
(drop
(string.const "foo")
)
(drop
(string.const "bar")
)
;; Control characters, a quote, a backslash and a multi-byte character: still
;; expected as a "'" import, printed with the name escaped.
(drop
(string.const "needs\tescaping\00.'#%- .\r\n\\08\0C\0A\0D\09.ꙮ")
)
;; Strings containing an unpaired surrogate: expected to fall back to the
;; "string.const" module with numeric base names ("0", "1").
(drop
(string.const "unpaired high surrogate \ED\A0\80 ")
)
(drop
(string.const "unpaired low surrogate \ED\BD\88 ")
)
)
)
Loading