diff --git a/recml/inference/benchmarks/DLRM_DCNv2/ckpt_load_and_eval.sh b/recml/inference/benchmarks/DLRM_DCNv2/ckpt_load_and_eval.sh index 02422da..00e68df 100644 --- a/recml/inference/benchmarks/DLRM_DCNv2/ckpt_load_and_eval.sh +++ b/recml/inference/benchmarks/DLRM_DCNv2/ckpt_load_and_eval.sh @@ -6,14 +6,14 @@ export XLA_FLAGS= export TPU_NAME= export LEARNING_RATE=0.0034 -export BATCH_SIZE=135168 +export BATCH_SIZE=4224 export EMBEDDING_SIZE=128 export MODEL_DIR=/tmp/ export FILE_PATTERN=gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/train-* export NUM_STEPS=28000 export CHECKPOINT_INTERVAL=1500 export EVAL_INTERVAL=1500 -export EVAL_FILE_PATTER=gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/eval-* +export EVAL_FILE_PATTERN=gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/eval-* export EVAL_STEPS=660 export MODE=eval export EMBEDDING_THRESHOLD=21000 @@ -21,9 +21,7 @@ export LOGGING_INTERVAL=1500 export RESTORE_CHECKPOINT=true - python recml/inference/models/jax/DLRM_DCNv2/dlrm_main.py \ - --learning_rate=${LEARNING_RATE} \ --batch_size=${BATCH_SIZE} \ --embedding_size=${EMBEDDING_SIZE} \ diff --git a/recml/inference/benchmarks/DLRM_DCNv2/train_and_checkpoint.sh b/recml/inference/benchmarks/DLRM_DCNv2/train_and_checkpoint.sh index e32639c..c3b599f 100644 --- a/recml/inference/benchmarks/DLRM_DCNv2/train_and_checkpoint.sh +++ b/recml/inference/benchmarks/DLRM_DCNv2/train_and_checkpoint.sh @@ -6,14 +6,14 @@ export XLA_FLAGS= export TPU_NAME= export LEARNING_RATE=0.0034 -export BATCH_SIZE=135168 +export BATCH_SIZE=4224 export EMBEDDING_SIZE=128 export MODEL_DIR=/tmp/ export FILE_PATTERN=gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/train-* export NUM_STEPS=28000 export CHECKPOINT_INTERVAL=1500 export EVAL_INTERVAL=1500 -export EVAL_FILE_PATTER=gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/eval-* +export EVAL_FILE_PATTERN=gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/eval-* export 
EVAL_STEPS=660 export MODE=train export EMBEDDING_THRESHOLD=21000 @@ -21,7 +21,6 @@ export LOGGING_INTERVAL=1500 export RESTORE_CHECKPOINT=true python recml/inference/models/jax/DLRM_DCNv2/dlrm_main.py \ - --learning_rate=${LEARNING_RATE} \ --batch_size=${BATCH_SIZE} \ --embedding_size=${EMBEDDING_SIZE} \ diff --git a/recml/inference/benchmarks/README.md b/recml/inference/benchmarks/README.md index 4c05b5f..98306b2 100644 --- a/recml/inference/benchmarks/README.md +++ b/recml/inference/benchmarks/README.md @@ -54,10 +54,10 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${Z gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE} --worker=all --command="pip install -U tensorflow dm-tree flax google-metrax" ``` -#### Run workload +#### Make script executable and Run workload Note: Please update the MODEL_NAME & TASK_NAME before running the below command ``` -gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE} --worker=all --command="TPU_NAME=${TPU_NAME} ./inference/benchmarks/<MODEL_NAME>/<TASK_NAME>" +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE} --worker=all --command="cd RecML && chmod +x ./recml/inference/benchmarks/<MODEL_NAME>/<TASK_NAME> && TPU_NAME=${TPU_NAME} ./recml/inference/benchmarks/<MODEL_NAME>/<TASK_NAME>" ``` diff --git a/requirements.txt b/requirements.txt index 580d6c9..a82b751 100644 --- a/requirements.txt +++ b/requirements.txt @@ -63,7 +63,7 @@ platformdirs==4.3.7 pluggy==1.5.0 pre-commit==4.2.0 promise==2.3 -protobuf==5.29.4 +protobuf==4.21.12 psutil==7.0.0 pyarrow==19.0.1 pygments==2.19.1