suppress errors in reagent
Differential Revision: D44719105

fbshipit-source-id: 9e73e110d4c3ed858ac9e8944c73404bb6fa6122
generatedunixname89002005307016 authored and facebook-github-bot committed Apr 5, 2023
1 parent 9e1a350 commit d94632c
Showing 3 changed files with 0 additions and 36 deletions.
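Context for the diff below: a `# pyre-fixme[CODE]` comment tells the Pyre type checker to ignore the error with that code on the line immediately beneath it; as the removed comments themselves show, code 21 is an unresolved import ("Could not find `pyspark`") and code 16 is an undefined attribute ("Module `functions` has no attribute `col`"). This commit only deletes such suppressions. A minimal sketch of the suppression pattern, assuming pyspark is installed locally even though Pyre's search path could not see it:

# Minimal sketch of the pattern being removed; the pyre-fixme lines are plain
# comments at runtime, so the script runs unchanged with or without them.

# pyre-fixme[21]: Could not find `pyspark`.
from pyspark.sql import SparkSession

# pyre-fixme[21]: Could not find `pyspark`.
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1.0,)], ["reward"])

# pyre-fixme[16]: Module `functions` has no attribute `col`.
df = df.withColumn("reward_copy", col("reward"))
df.show()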
29 changes: 0 additions & 29 deletions reagent/data/oss_data_fetcher.py
@@ -3,8 +3,6 @@
 import logging
 from typing import List, Optional, Tuple
 
-# pyre-fixme[21]: Could not find `pyspark`.
-# pyre-fixme[21]: Could not find `pyspark`.
 from pyspark.sql.functions import col, crc32, explode, map_keys, udf
 from pyspark.sql.types import (
     ArrayType,
@@ -81,13 +79,10 @@ def hash_mdp_id_and_subsample(df, sample_range: Optional[Tuple[float, float]] =
         and sample_range[1] <= 100.0
     ), f"{sample_range} is invalid."
 
-    # pyre-fixme[16]: Module `functions` has no attribute `col`.
     df = df.withColumn("mdp_id", crc32(col("mdp_id")))
     if sample_range:
         lower_bound = sample_range[0] / 100.0 * MAX_UINT32
         upper_bound = sample_range[1] / 100.0 * MAX_UINT32
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         df = df.filter((lower_bound <= col("mdp_id")) & (col("mdp_id") <= upper_bound))
     return df
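A hedged illustration (not part of the commit) of the arithmetic this hunk relies on: crc32 maps each mdp_id onto the unsigned 32-bit range, so keeping only rows whose hash falls inside a percentage window of MAX_UINT32 yields a deterministic subsample of roughly (upper - lower) percent of the mdp_ids.

import zlib

MAX_UINT32 = 4294967295  # 2**32 - 1, matching the constant used in the module
sample_range = (0.0, 10.0)  # hypothetical: keep roughly 10% of mdp_ids

lower_bound = sample_range[0] / 100.0 * MAX_UINT32
upper_bound = sample_range[1] / 100.0 * MAX_UINT32

# zlib.crc32 returns an unsigned 32-bit value from the same CRC-32 family
# that Spark's crc32() column function computes.
hashed = zlib.crc32(b"2021-01-01_episode_42")  # hypothetical mdp_id
keep = lower_bound <= hashed <= upper_bound
print(keep)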

@@ -121,9 +116,7 @@ def sparse2dense(map_col):
 
     sparse2dense_udf = udf(sparse2dense, output_type)
     df = df.withColumn(col_name, sparse2dense_udf(col_name))
-    # pyre-fixme[16]: Module `functions` has no attribute `col`.
     df = df.withColumn(f"{col_name}_presence", col(f"{col_name}.presence"))
-    # pyre-fixme[16]: Module `functions` has no attribute `col`.
     df = df.withColumn(col_name, col(f"{col_name}.dense"))
     return df
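For readers outside the codebase, a hedged sketch of the sparse-to-dense pattern inferred from the hunk above, using a hypothetical fixed feature ordering: a UDF turns a map-typed feature column into a struct holding a dense value array plus a boolean presence mask, and the two withColumn calls then split that struct into `<col>` and `<col>_presence`.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    FloatType,
    StructField,
    StructType,
)

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([({1: 0.5, 3: 2.0},)], ["state_features"])

feature_ids = [1, 2, 3]  # hypothetical fixed ordering of known feature ids

def sparse2dense(map_col):
    # Missing features get a 0.0 value and a False presence flag.
    dense = [float(map_col.get(fid, 0.0)) for fid in feature_ids]
    presence = [fid in map_col for fid in feature_ids]
    return dense, presence

output_type = StructType(
    [
        StructField("dense", ArrayType(FloatType()), False),
        StructField("presence", ArrayType(BooleanType()), False),
    ]
)
sparse2dense_udf = udf(sparse2dense, output_type)

col_name = "state_features"
df = df.withColumn(col_name, sparse2dense_udf(col_name))
df = df.withColumn(f"{col_name}_presence", col(f"{col_name}.presence"))
df = df.withColumn(col_name, col(f"{col_name}.dense"))
df.show(truncate=False)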

@@ -193,7 +186,6 @@ def misc_column_preprocessing(df, multi_steps: Optional[int]):
     df = df.withColumn("time_diff", next_long_udf("time_diff"))
 
     # assuming use_seq_num_diff_as_time_diff = False for now
-    # pyre-fixme[16]: Module `functions` has no attribute `col`.
     df = df.withColumn("sequence_number", col("sequence_number_ordinal"))
 
     return df
@@ -302,58 +294,37 @@ def select_relevant_columns(
         raise NotImplementedError("currently we don't support include_possible_actions")
 
     select_col_list = [
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("reward").cast(FloatType()),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("state_features").cast(ArrayType(FloatType())),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("state_features_presence").cast(ArrayType(BooleanType())),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("next_state_features").cast(ArrayType(FloatType())),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("next_state_features_presence").cast(ArrayType(BooleanType())),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("not_terminal").cast(BooleanType()),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("action_probability").cast(FloatType()),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("mdp_id").cast(LongType()),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("sequence_number").cast(LongType()),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("step").cast(LongType()),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("time_diff").cast(LongType()),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("metrics").cast(ArrayType(FloatType())),
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         col("metrics_presence").cast(ArrayType(BooleanType())),
     ]
 
     if discrete_action:
         select_col_list += [
-            # pyre-fixme[16]: Module `functions` has no attribute `col`.
             col("action").cast(LongType()),
-            # pyre-fixme[16]: Module `functions` has no attribute `col`.
             col("next_action").cast(LongType()),
         ]
     else:
         select_col_list += [
-            # pyre-fixme[16]: Module `functions` has no attribute `col`.
             col("action").cast(ArrayType(FloatType())),
-            # pyre-fixme[16]: Module `functions` has no attribute `col`.
             col("next_action").cast(ArrayType(FloatType())),
-            # pyre-fixme[16]: Module `functions` has no attribute `col`.
             col("action_presence").cast(ArrayType(BooleanType())),
-            # pyre-fixme[16]: Module `functions` has no attribute `col`.
             col("next_action_presence").cast(ArrayType(BooleanType())),
         ]
 
     if include_possible_actions:
         select_col_list += [
-            # pyre-fixme[16]: Module `functions` has no attribute `col`.
             col("possible_actions_mask").cast(ArrayType(LongType())),
-            # pyre-fixme[16]: Module `functions` has no attribute `col`.
             col("possible_next_actions_mask").cast(ArrayType(LongType())),
         ]
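Outside the diff, a hedged usage sketch (toy DataFrame and a trimmed-down column list, not from the commit) of how a cast list like the one above is consumed: the expressions are applied in a single select, so downstream consumers see a stable, fully typed schema.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import ArrayType, FloatType, LongType

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("123", 1.0, [0.5, 2.0])],
    ["mdp_id", "reward", "state_features"],
)

# Trimmed-down stand-in for the select_col_list built above.
select_col_list = [
    col("reward").cast(FloatType()),
    col("mdp_id").cast(LongType()),
    col("state_features").cast(ArrayType(FloatType())),
]
df = df.select(*select_col_list)
df.printSchema()  # reward: float, mdp_id: bigint, state_features: array<float>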

3 changes: 0 additions & 3 deletions reagent/data/spark_utils.py
@@ -9,8 +9,6 @@
 import reagent
 from pyspark.sql import SparkSession
 
-# pyre-fixme[21]: Could not find module `pyspark.sql.functions`.
-# pyre-fixme[21]: Could not find module `pyspark.sql.functions`.
 from pyspark.sql.functions import col
 
 
@@ -72,7 +70,6 @@ def get_table_url(table_name: str) -> str:
     spark = get_spark_session()
     url = (
         spark.sql(f"DESCRIBE FORMATTED {table_name}")
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         .filter((col("col_name") == "Location"))
         .select("data_type")
         .toPandas()
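The chain is cut off at `.toPandas()` in the visible context. A hedged standalone sketch of the lookup, assuming a live SparkSession and an existing table (the final pandas indexing is an assumption, not taken from the diff): DESCRIBE FORMATTED returns rows of (col_name, data_type, comment), and the row whose col_name is "Location" carries the table's storage URL in its data_type field.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

def table_location(spark: SparkSession, table_name: str) -> str:
    # Hypothetical variant of get_table_url: pull the "Location" row out of
    # DESCRIBE FORMATTED and return its data_type value.
    described = spark.sql(f"DESCRIBE FORMATTED {table_name}")
    pdf = described.filter(col("col_name") == "Location").select("data_type").toPandas()
    return pdf["data_type"].iloc[0]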
4 changes: 0 additions & 4 deletions reagent/workflow/identify_types_flow.py
@@ -5,8 +5,6 @@
 
 import reagent.core.types as rlt
 
-# pyre-fixme[21]: Could not find `pyspark`.
-# pyre-fixme[21]: Could not find `pyspark`.
 from pyspark.sql.functions import col, collect_list, explode
 from reagent.data.spark_utils import get_spark_session
 from reagent.preprocessing.normalization import (
@@ -114,7 +112,6 @@ def create_normalization_spec_spark(
 
     # assumes column has a type of map
     df = df.select(
-        # pyre-fixme[16]: Module `functions` has no attribute `col`.
         explode(col(column).alias("features")).alias("feature_name", "feature_value")
     )
 
@@ -129,7 +126,6 @@ def create_normalization_spec_spark(
     # perform sampling and collect them
     df = df.sampleBy("feature_name", fractions=frac, seed=seed)
     df = df.groupBy("feature_name").agg(
-        # pyre-fixme[16]: Module `functions` has no attribute `collect_list`.
         collect_list("feature_value").alias("feature_values")
     )
     return df
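A hedged end-to-end sketch (toy data and hypothetical sampling fractions standing in for `frac`; not part of the commit) of the pipeline these two hunks belong to: explode a map-typed feature column into (feature_name, feature_value) rows, stratified-sample per feature with sampleBy, then gather the sampled values back into one list per feature with collect_list.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_list, explode

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [({1: 0.5, 2: 1.0},), ({1: 0.25},)],
    ["state_features"],
)

# Explode the map so every (feature id, value) pair becomes its own row.
rows = df.select(
    explode(col("state_features")).alias("feature_name", "feature_value")
)

# Hypothetical per-feature sampling fractions.
frac = {1: 1.0, 2: 1.0}
sampled = rows.sampleBy("feature_name", fractions=frac, seed=42)

# One row per feature id, carrying the list of sampled values.
spec = sampled.groupBy("feature_name").agg(
    collect_list("feature_value").alias("feature_values")
)
spec.show(truncate=False)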
