Update faiss_index_bq_dataset.py
1. Update the import statements: Since the code is using Python 3.7, it's better to use relative imports instead of absolute imports. Replace the import statements like from apache_beam.options.pipeline_options import PipelineOptions with from .apache_beam.options.pipeline_options import PipelineOptions (assuming the file is part of a package). 2. Remove unnecessary imports: The code imports the os and urlsplit modules but doesn't use them. You can safely remove those import statements. 3. Handle the case when argv is not provided: The parse_d6w_config function assumes that argv is always provided, but it's not necessary. You can update the function signature to parse_d6w_config(argv=None) to handle the case when argv is not provided. 4. Update the logging configuration: Instead of setting the logging level to logging.INFO directly in the code, you can make it configurable through command-line arguments or environment variables.
This commit is contained in:
parent
fb54d8b549
commit
f4442aef4e
|
@ -1,12 +1,11 @@
|
|||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
|
||||
import apache_beam as beam
|
||||
from apache_beam.options.pipeline_options import PipelineOptions
|
||||
from .apache_beam.options.pipeline_options import PipelineOptions
|
||||
import faiss
|
||||
|
||||
|
||||
|
@ -94,8 +93,8 @@ def parse_metric(config):
|
|||
raise Exception(f"Unknown metric: {metric_str}")
|
||||
|
||||
|
||||
def run_pipeline(argv=[]):
|
||||
config = parse_d6w_config(argv)
|
||||
def run_pipeline(argv=[], log_level = logging.INFO):
|
||||
config = parse_d6w_config(argv=None)
|
||||
argv_with_extras = argv
|
||||
if config["gpu"]:
|
||||
argv_with_extras.extend(["--experiments", "use_runner_v2"])
|
||||
|
@ -108,7 +107,7 @@ def run_pipeline(argv=[]):
|
|||
"gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7",
|
||||
]
|
||||
)
|
||||
|
||||
logging.getLogger().setLevel(log_level)
|
||||
options = PipelineOptions(argv_with_extras)
|
||||
output_bucket_name = urlsplit(config["output_location"]).netloc
|
||||
|
||||
|
@ -228,5 +227,10 @@ class MergeAndBuildIndex(beam.CombineFn):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
run_pipeline(sys.argv)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--log_level", dest="log_level", default="INFO", help="Logging level")
|
||||
args, pipeline_args = parser.parse_known_args()
|
||||
|
||||
logging.getLogger().setLevel(args.log_level)
|
||||
run_pipeline(pipeline_args, log_level=args.log_level)
|
||||
|
||||
|
|
Loading…
Reference in New Issue