Skip to content

Commit d14e7f5

Browse files
committed
Fix all the lint issues.
1 parent 595ead2 commit d14e7f5

14 files changed

+153
-99
lines changed

benchmarks/run_offline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def main(argv):
109109

110110
profiling_output = FLAGS.profiling_output
111111
print("======= decode starting ===")
112-
112+
113113
dec_times = []
114114
for i in range(10):
115115
if profiling_output and i == 7 and not profiler_started:

jetstream_pt/layers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def attend(xq, keys, values, local_mask=None):
456456
# When GQA is enabled, it is not necessary to expand
457457
if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1:
458458
true_len = 2
459-
#xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))
459+
# xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))
460460
xq = torch.nn.functional.pad(
461461
xq, (0, 0, 0, true_len - seqlen), "constant", 0
462462
)
@@ -714,6 +714,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
714714

715715
return attn_out
716716

717+
717718
class Attention(ModuleBase):
718719
"""Attention module."""
719720

run_interactive.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,10 @@
1818
from typing import List
1919

2020
# import torch_xla2 first!
21-
import torch_xla2 # pylint: disable
2221
import jax
2322
import numpy as np
24-
from absl import app, flags
25-
from colorama import Fore, Style
23+
from absl import app
2624
from jetstream.engine import token_utils
27-
from jetstream_pt import engine as je
2825
from jetstream_pt.config import FLAGS, create_engine_from_config_flags
2926

3027

@@ -54,10 +51,15 @@ def main(argv):
5451
if profiling_prefill:
5552
jax.profiler.stop_trace()
5653
prompts: List[str] = [
54+
# pylint: disable-next=all
5755
"I believe the meaning of life is",
56+
# pylint: disable-next=all
5857
"To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList<Person>`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList<Person> peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.",
58+
# pylint: disable-next=all
5959
"<s>[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<</SYS>>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]",
60+
# pylint: disable-next=all
6061
"<s>[INST] <<SYS>>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<</SYS>>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST",
62+
# pylint: disable-next=all
6163
"<s>[INST] <<SYS>>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<</SYS>>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
6264
]
6365
for prompt in prompts:

run_interactive_disaggregated.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
from typing import List
2020
from absl import app
2121
from absl import flags
22-
from colorama import Fore, Style
2322

24-
import numpy as np
2523
import jax
2624

2725
from jetstream.engine import token_utils
@@ -129,7 +127,6 @@ def main(argv):
129127
print("Load params ", time.perf_counter() - start)
130128

131129
metadata = prefill_engine.get_tokenizer()
132-
tokenizer = prefill_engine.build_tokenizer(metadata)
133130
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
134131
stop_tokens = [vocab.eos_id, vocab.pad_id]
135132
max_output_length = 1024
@@ -157,19 +154,21 @@ def main(argv):
157154
print(f"---- Input prompts are: {prompt}")
158155
print(f"---- Encoded tokens are: {tokens}")
159156

160-
# pylint: disable-next=all
161157
print(
158+
# pylint: disable-next=all
162159
f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}"
163160
)
164161
prefill_result, _ = prefill_engine.prefill(
165162
params=None, padded_tokens=tokens, true_length=true_length
166163
)
167164
print(
165+
# pylint: disable-next=all
168166
f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}"
169167
)
170168
decode_engine.transfer(prefill_result)
171-
# pylint: disable-next=all
169+
172170
print(
171+
# pylint: disable-next=all
173172
f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}"
174173
)
175174
decode_state = decode_engine.insert(prefill_result, None, slot=slot)

run_interactive_multiple_host.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import jax
2121
from absl import app, flags
22-
from colorama import Fore, Style
2322
from jetstream.engine import token_utils
2423
from jetstream_pt import ray_engine
2524
from jetstream_pt.config import FLAGS

run_ray_serve_interleave.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141

4242
def create_head_resource_name(generation, tpu_chips):
43+
"""Create head resource name."""
4344
return f"TPU-{generation}-{tpu_chips}-head"
4445

4546

@@ -73,6 +74,7 @@ def create_engine(**kwargs):
7374

7475
@serve.deployment
7576
class JetStreamDeployment:
77+
"""JetStream deployment."""
7678

7779
def __init__(self, **kwargs):
7880
os.environ["XLA_FLAGS"] = (
@@ -111,18 +113,24 @@ def __init__(self, **kwargs):
111113

112114
print("Started jetstream driver....")
113115

116+
# pylint: disable-next=all
114117
async def Decode(
115-
self, request: jetstream_pb2.DecodeRequest
118+
self,
119+
# pylint: disable-next=all
120+
request: jetstream_pb2.DecodeRequest,
121+
# pylint: disable-next=all
116122
) -> AsyncIterator[jetstream_pb2.DecodeResponse]:
117-
123+
"""Async decode function."""
118124
return self.orchestrator.Decode(request)
119125

120126

121127
def main(_argv):
128+
"""Main function"""
122129
resource_name = create_head_resource_name(
123130
FLAGS.tpu_generation, FLAGS.tpu_chips
124131
)
125132
print(f"Using head resource {resource_name}")
133+
# pylint: disable-next=all
126134
deployment = JetStreamDeployment.options(
127135
ray_actor_options={"resources": {resource_name: 1}}
128136
).bind(

run_server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import Sequence
1818

1919
# import torch_xla2 first!
20-
import torch_xla2 # pylint: disable
2120
import jax
2221
from absl import app, flags
2322
from jetstream.core import server_lib

run_server_with_ray.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from absl import app, flags
2020

2121
# import torch_xla2 first!
22-
import torch_xla2 # pylint: disable
2322
import jax
2423
from jetstream.core import server_lib
2524
from jetstream.core.config_lib import ServerConfig

tests/helpers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from jetstream_pt import environment
77

88

9+
# pylint: disable-next=all
910
def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
1011
torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
1112
torch.set_default_dtype(torch_dtype)
@@ -33,6 +34,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
3334
return env, config
3435

3536

37+
# pylint: disable-next=all
3638
def make_mixtral_env(bf16_enable=True):
3739
torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
3840
torch.set_default_dtype(torch_dtype)
@@ -57,14 +59,16 @@ def make_mixtral_env(bf16_enable=True):
5759
return env, config
5860

5961

62+
# pylint: disable-next=all
6063
def to_xla_tensor(tree):
6164
return torch_xla2.default_env().to_xla(tree)
6265

6366

67+
# pylint: disable-next=all
6468
def call_xla_model(model, weights, args):
6569
with jax.default_device(jax.devices("cpu")[0]):
6670
xla_weights, xla_inputs = to_xla_tensor((weights, args))
6771
with torch_xla2.default_env():
6872
result = torch.func.functional_call(model, xla_weights, xla_inputs)
69-
result_torch = torch_xla2.tensor.j2t(result._elem)
73+
result_torch = torch_xla2.tensor.j2t(result.jax())
7074
return result_torch

tests/test_hf_names.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44

55

66
class TestModuleBase(unittest.TestCase):
7+
"""Test module base."""
78

89
def test_get_hf_names_to_real_name(self):
10+
"""Test get huggingface names to real name."""
911

1012
class MyModule(ModuleBase):
13+
"""My module."""
1114

1215
def __init__(self):
1316
super().__init__()
@@ -18,6 +21,9 @@ def __init__(self):
1821
self.param = torch.nn.Parameter(torch.randn(10))
1922
self.hf_name("param", "model.param")
2023

24+
def forward(self):
25+
"""Forward function."""
26+
2127
module = MyModule()
2228
expected_mapping = {
2329
"model.my_linear1.weight": "linear1.weight",
@@ -30,20 +36,30 @@ def __init__(self):
3036
self.assertEqual(module.get_hf_names_to_real_name(), expected_mapping)
3137

3238
def test_get_sharding_annotations(self):
39+
"""Test get sharding annotations."""
40+
3341
class MyModule(ModuleBase):
42+
"""MyModule."""
3443

3544
def __init__(self):
3645
super().__init__()
3746
self.linear = torch.nn.Linear(10, 20)
3847
self.embedding = torch.nn.Embedding(100, 50)
3948
self.inner = InnerModule()
4049

50+
def forward(self):
51+
"""Forward function."""
52+
4153
class InnerModule(ModuleBase):
54+
"""Inner module."""
4255

4356
def __init__(self):
4457
super().__init__()
4558
self.fc = torch.nn.Linear(50, 100)
4659

60+
def forward(self):
61+
"""Forward function."""
62+
4763
module = MyModule()
4864
module.annotate_sharding("linear.weight", 0)
4965
module.annotate_sharding("embedding.weight", 1)

0 commit comments

Comments
 (0)