Skip to content

Commit 6cb17a8

Browse files
committed
add options to disable file generation
1 parent 71f51ce commit 6cb17a8

File tree

4 files changed

+146
-49
lines changed

4 files changed

+146
-49
lines changed

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -469,17 +469,18 @@ def _save_report(
469469
result_df: pd.DataFrame,
470470
metrics_df: pd.DataFrame,
471471
test_metrics_df: pd.DataFrame,
472-
# test_data: pd.DataFrame,
472+
test_data: pd.DataFrame,
473473
):
474474
"""Saves resulting reports to the given folder."""
475475

476476
unique_output_dir = self.spec.output_directory.url
477477
results = ForecastResults()
478478

479-
if ObjectStorageDetails.is_oci_path(unique_output_dir):
480-
storage_options = default_signer()
481-
else:
482-
storage_options = {}
479+
storage_options = (
480+
default_signer()
481+
if ObjectStorageDetails.is_oci_path(unique_output_dir)
482+
else {}
483+
)
483484

484485
# report-creator html report
485486
if self.spec.generate_report:
@@ -510,12 +511,13 @@ def _save_report(
510511
if self.target_cat_col
511512
else result_df.drop(DataColumns.Series, axis=1)
512513
)
513-
write_data(
514-
data=result_df,
515-
filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
516-
format="csv",
517-
storage_options=storage_options,
518-
)
514+
if self.spec.generate_forecast_file:
515+
write_data(
516+
data=result_df,
517+
filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
518+
format="csv",
519+
storage_options=storage_options,
520+
)
519521
results.set_forecast(result_df)
520522

521523
# metrics csv report
@@ -529,15 +531,16 @@ def _save_report(
529531
metrics_df_formatted = metrics_df.reset_index().rename(
530532
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
531533
)
532-
write_data(
533-
data=metrics_df_formatted,
534-
filename=os.path.join(
535-
unique_output_dir, self.spec.metrics_filename
536-
),
537-
format="csv",
538-
storage_options=storage_options,
539-
index=False,
540-
)
534+
if self.spec.generate_metrics_file:
535+
write_data(
536+
data=metrics_df_formatted,
537+
filename=os.path.join(
538+
unique_output_dir, self.spec.metrics_filename
539+
),
540+
format="csv",
541+
storage_options=storage_options,
542+
index=False,
543+
)
541544
results.set_metrics(metrics_df_formatted)
542545
else:
543546
logger.warn(
@@ -550,15 +553,16 @@ def _save_report(
550553
test_metrics_df_formatted = test_metrics_df.reset_index().rename(
551554
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
552555
)
553-
write_data(
554-
data=test_metrics_df_formatted,
555-
filename=os.path.join(
556-
unique_output_dir, self.spec.test_metrics_filename
557-
),
558-
format="csv",
559-
storage_options=storage_options,
560-
index=False,
561-
)
556+
if self.spec.generate_metrics_file:
557+
write_data(
558+
data=test_metrics_df_formatted,
559+
filename=os.path.join(
560+
unique_output_dir, self.spec.test_metrics_filename
561+
),
562+
format="csv",
563+
storage_options=storage_options,
564+
index=False,
565+
)
562566
results.set_test_metrics(test_metrics_df_formatted)
563567
else:
564568
logger.warn(
@@ -568,31 +572,33 @@ def _save_report(
568572
if self.spec.generate_explanations:
569573
try:
570574
if not self.formatted_global_explanation.empty:
571-
write_data(
572-
data=self.formatted_global_explanation,
573-
filename=os.path.join(
574-
unique_output_dir, self.spec.global_explanation_filename
575-
),
576-
format="csv",
577-
storage_options=storage_options,
578-
index=True,
579-
)
575+
if not self.spec.generate_explanations_file:
576+
write_data(
577+
data=self.formatted_global_explanation,
578+
filename=os.path.join(
579+
unique_output_dir, self.spec.global_explanation_filename
580+
),
581+
format="csv",
582+
storage_options=storage_options,
583+
index=True,
584+
)
580585
results.set_global_explanations(self.formatted_global_explanation)
581586
else:
582587
logger.warn(
583588
f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
584589
)
585590

586591
if not self.formatted_local_explanation.empty:
587-
write_data(
588-
data=self.formatted_local_explanation,
589-
filename=os.path.join(
590-
unique_output_dir, self.spec.local_explanation_filename
591-
),
592-
format="csv",
593-
storage_options=storage_options,
594-
index=True,
595-
)
592+
if not self.spec.generate_explanations_file:
593+
write_data(
594+
data=self.formatted_local_explanation,
595+
filename=os.path.join(
596+
unique_output_dir, self.spec.local_explanation_filename
597+
),
598+
format="csv",
599+
storage_options=storage_options,
600+
index=True,
601+
)
596602
results.set_local_explanations(self.formatted_local_explanation)
597603
else:
598604
logger.warn(

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,23 @@
1818

1919
from .const import SpeedAccuracyMode, SupportedMetrics, SupportedModels
2020

21+
2122
@dataclass
2223
class AutoScaling(DataClassSerializable):
2324
"""Class representing simple autoscaling policy"""
25+
2426
minimum_instance: int = 1
2527
maximum_instance: int = None
2628
cool_down_in_seconds: int = 600
2729
scale_in_threshold: int = 10
2830
scale_out_threshold: int = 80
2931
scaling_metric: str = "CPU_UTILIZATION"
3032

33+
3134
@dataclass(repr=True)
3235
class ModelDeploymentServer(DataClassSerializable):
3336
"""Class representing model deployment server specification for whatif-analysis."""
37+
3438
display_name: str = None
3539
initial_shape: str = None
3640
description: str = None
@@ -42,10 +46,13 @@ class ModelDeploymentServer(DataClassSerializable):
4246
@dataclass(repr=True)
4347
class WhatIfAnalysis(DataClassSerializable):
4448
"""Class representing operator specification for whatif-analysis."""
49+
4550
model_display_name: str = None
4651
compartment_id: str = None
4752
project_id: str = None
48-
model_deployment: ModelDeploymentServer = field(default_factory=ModelDeploymentServer)
53+
model_deployment: ModelDeploymentServer = field(
54+
default_factory=ModelDeploymentServer
55+
)
4956

5057

5158
@dataclass(repr=True)
@@ -106,8 +113,11 @@ class ForecastOperatorSpec(DataClassSerializable):
106113
datetime_column: DateTimeColumn = field(default_factory=DateTimeColumn)
107114
target_category_columns: List[str] = field(default_factory=list)
108115
generate_report: bool = None
116+
generate_forecast_file: bool = None
109117
generate_metrics: bool = None
118+
generate_metrics_file: bool = None
110119
generate_explanations: bool = None
120+
generate_explanations_file: bool = None
111121
explanations_accuracy_mode: str = None
112122
horizon: int = None
113123
model: str = None
@@ -126,7 +136,9 @@ def __post_init__(self):
126136
self.output_directory = self.output_directory or OutputDirectory(
127137
url=find_output_dirname(self.output_directory)
128138
)
129-
self.generate_model_pickle = True if self.generate_model_pickle or self.what_if_analysis else False
139+
self.generate_model_pickle = (
140+
True if self.generate_model_pickle or self.what_if_analysis else False
141+
)
130142
self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
131143
self.model = self.model or SupportedModels.Prophet
132144
self.confidence_interval_width = self.confidence_interval_width or 0.80
@@ -144,6 +156,21 @@ def __post_init__(self):
144156
self.generate_metrics = (
145157
self.generate_metrics if self.generate_metrics is not None else True
146158
)
159+
self.generate_metrics_file = (
160+
self.generate_metrics_file
161+
if self.generate_metrics_file is not None
162+
else True
163+
)
164+
self.generate_forecast_file = (
165+
self.generate_forecast_file
166+
if self.generate_forecast_file is not None
167+
else True
168+
)
169+
self.generate_explanations_file = (
170+
self.generate_explanations_file
171+
if self.generate_explanations_file is not None
172+
else True
173+
)
147174
# For Explanations Generation. When user doesn't specify defaults to False
148175
self.generate_explanations = (
149176
self.generate_explanations

docs/source/user_guide/operators/forecast_operator/development.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,35 @@ Before running operators on a job, users must configure their output directory.
125125
horizon: 3
126126
target_column: y
127127
128+
129+
Exclude Writing Certain Output Files
130+
====================================
131+
132+
You can choose to exclude certain files from being written to the output folder. This may be because you are calling the API, and not using the output folder. The yaml options below are ``True`` by default, but can be set to ``False`` to prevent file generation.
133+
134+
.. code-block:: yaml
135+
136+
kind: operator
137+
type: forecast
138+
version: v1
139+
spec:
140+
datetime_column:
141+
name: ds
142+
historical_data:
143+
url: oci://<bucket_name>@<namespace_name>/example_yosemite_temps.csv
144+
output_directory:
145+
url: oci://<bucket_name>@<namespace_name>/my_results/
146+
horizon: 3
147+
target_column: y
148+
generate_report: True
149+
generate_forecast_file: False
150+
generate_metrics_file: False
151+
generate_explanations: True
152+
generate_explanations_file: False
153+
154+
The above example will save a report.html to ``oci://<bucket_name>@<namespace_name>/my_results/``, but it will NOT save other files.
155+
156+
128157
Ingesting and Interpreting Outputs
129158
==================================
130159

tests/operators/forecast/test_errors.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,5 +921,40 @@ def test_report_title(operator_setup, model):
921921
assert False, "Report Title was not set"
922922

923923

924+
@pytest.mark.parametrize("model", ["prophet"])
925+
def test_generate_files(operator_setup, model):
926+
from ads.opctl.operator.lowcode.forecast.__main__ import operate
927+
from ads.opctl.operator.lowcode.forecast.operator_config import (
928+
ForecastOperatorConfig,
929+
)
930+
931+
yaml_i = TEMPLATE_YAML.copy()
932+
yaml_i["spec"]["horizon"] = 10
933+
yaml_i["spec"]["model"] = model
934+
yaml_i["spec"]["historical_data"] = {"format": "pandas"}
935+
yaml_i["spec"]["target_column"] = TARGET_COL.name
936+
yaml_i["spec"]["datetime_column"]["name"] = HISTORICAL_DATETIME_COL.name
937+
yaml_i["spec"]["report_title"] = "Skibidi ADS Skibidi"
938+
yaml_i["spec"]["output_directory"]["url"] = operator_setup
939+
yaml_i["spec"]["generate_explanations_file"] = False
940+
yaml_i["spec"]["generate_forecast_file"] = False
941+
yaml_i["spec"]["generate_metrics_file"] = False
942+
943+
df = pd.concat([HISTORICAL_DATETIME_COL[:15], TARGET_COL[:15]], axis=1)
944+
yaml_i["spec"]["historical_data"]["data"] = df
945+
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
946+
results = operate(operator_config)
947+
files = os.listdir(operator_setup)
948+
assert "report.html" in files, "Failed to generate report"
949+
assert (
950+
"forecast.csv" not in files
951+
), "Generated forecast file, but `generate_forecast_file` was set False"
952+
assert (
953+
"metrics.csv" not in files
954+
), "Generated metrics file, but `generate_metrics_file` was set False"
955+
assert not results.get_forecast().empty
956+
assert not results.get_metrics().empty
957+
958+
924959
if __name__ == "__main__":
925960
pass

0 commit comments

Comments
 (0)